diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b7ba0214..6e75ae7e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,8 +16,14 @@ jobs: do_not_skip: '["pull_request"]' cancel_others: 'true' concurrent_skipping: same_content - test: + ruff: + runs-on: ubuntu-latest needs: pre_job + steps: + - uses: actions/checkout@v4 + - uses: chartboost/ruff-action@v1 + test: + needs: ruff if: ${{ needs.pre_job.outputs.should_skip != 'true' }} runs-on: ubuntu-latest strategy: @@ -43,6 +49,6 @@ jobs: run: | poetry run pytest -vvv -ra --cov=cryptojwt --cov-report=xml --isort --black - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..1959583d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-merge-conflict + - id: debug-statements + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-json + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.9 + hooks: + - id: ruff + - id: ruff-format diff --git a/doc/conf.py b/doc/conf.py old mode 100644 new mode 100755 index 28cd9204..8b9d1745 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# +# ruff: noqa # -*- coding: utf-8 -*- # # CryptoJWT documentation build configuration file, created by diff --git a/pyproject.toml b/pyproject.toml index 4935e69c..39dfdeb3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ python = "^3.9" cryptography = ">=3.4.6" requests = "^2.25.1" -[tool.poetry.dev-dependencies] +[tool.poetry.group.dev.dependencies] alabaster = "^0.7.12" black = "^24.4.2" isort = "^5.13.2" @@ -54,7 +54,27 @@ responses = "^0.13.0" sphinx = "^3.5.2" sphinx-autobuild = "^2021.3.14" coverage = "^7" +ruff = "^0.4.6" +pytest-ruff = "^0.3.2" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", +] +ignore = ["E501", "I001", "SIM102"] +exclude = ["examples/*"] diff --git a/src/cryptojwt/__init__.py b/src/cryptojwt/__init__.py index 38787f07..5938c11e 100644 --- a/src/cryptojwt/__init__.py +++ b/src/cryptojwt/__init__.py @@ -16,15 +16,23 @@ from .utils import b64encode_item from .utils import split_token -try: - from builtins import hex - from builtins import str - from builtins import zip -except ImportError: - pass - __version__ = version("cryptojwt") +__all__ = [ + "JWE", + "JWE", + "JWK", + "JWS", + "JWT", + "KeyBundle", + "KeyJar", + "BadSyntax", + "as_unicode", + "b64d", + "b64encode_item", + "split_token", +] + logger = logging.getLogger(__name__) JWT_TYPES = ("JWT", "application/jws", "JWS", "JWE") diff --git a/src/cryptojwt/exception.py b/src/cryptojwt/exception.py index 83bb1d35..63c04012 100644 --- a/src/cryptojwt/exception.py +++ b/src/cryptojwt/exception.py @@ -20,7 +20,7 @@ def __init__(self, value, msg): self.msg = msg def __str__(self): - return "%s: %r" % (self.msg, self.value) + return f"{self.msg}: {self.value!r}" class BadSignature(Invalid): diff --git a/src/cryptojwt/jwe/__init__.py b/src/cryptojwt/jwe/__init__.py index ffddcf13..c0f0d658 100644 --- a/src/cryptojwt/jwe/__init__.py +++ b/src/cryptojwt/jwe/__init__.py @@ -38,7 +38,7 @@ } -class Encrypter(object): +class Encrypter: """Abstract base class for encryption algorithms.""" def __init__(self, with_digest=False): diff --git a/src/cryptojwt/jwe/aes.py b/src/cryptojwt/jwe/aes.py index c34ab459..502902e4 100644 --- a/src/cryptojwt/jwe/aes.py +++ b/src/cryptojwt/jwe/aes.py @@ -30,7 +30,7 @@ def __init__(self, key_len=32, key=None, msg_padding="PKCS7"): self.padder = PKCS7(128).padder() self.unpadder = PKCS7(128).unpadder() else: - raise Unsupported("Message padding: {}".format(msg_padding)) + raise Unsupported(f"Message padding: {msg_padding}") self.iv = None diff --git a/src/cryptojwt/jwe/jwe.py b/src/cryptojwt/jwe/jwe.py index 5935c9b6..ad11244f 100644 --- a/src/cryptojwt/jwe/jwe.py +++ b/src/cryptojwt/jwe/jwe.py @@ -1,14 +1,13 @@ +import contextlib import logging from ..jwk.asym import AsymmetricKey from ..jwk.ec import ECKey from ..jwk.hmac import SYMKey from ..jwk.jwk import key_from_jwk_dict -from ..jwk.rsa import RSAKey from ..jwx import JWx from .exception import DecryptionFailed from .exception import NoSuitableDecryptionKey -from .exception import NoSuitableECDHKey from .exception import NoSuitableEncryptionKey from .exception import NotSupportedAlgorithm from .exception import WrongEncryptionAlgorithm @@ -133,7 +132,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs): except TypeError as err: raise err else: - logger.debug("Encrypted message using key with kid={}".format(key.kid)) + logger.debug(f"Encrypted message using key with kid={key.kid}") return token # logger.error("Could not find any suitable encryption key") @@ -159,10 +158,8 @@ def decrypt(self, token=None, keys=None, alg=None, cek=None): else: keys = self.pick_keys(self._get_keys(), use="enc", alg=_alg) - try: + with contextlib.suppress(KeyError): keys.append(key_from_jwk_dict(_jwe.headers["jwk"])) - except KeyError: - pass if not keys and not cek: raise NoSuitableDecryptionKey(_alg) @@ -194,10 +191,7 @@ def decrypt(self, token=None, keys=None, alg=None, cek=None): return msg for key in keys: - if isinstance(key, AsymmetricKey): - _key = key.private_key() - else: - _key = key.key + _key = key.private_key() if isinstance(key, AsymmetricKey) else key.key try: msg = decrypter.decrypt(_jwe, _key) @@ -205,7 +199,7 @@ def decrypt(self, token=None, keys=None, alg=None, cek=None): except (KeyError, DecryptionFailed): pass else: - logger.debug("Decrypted message using key with kid=%s" % key.kid) + logger.debug(f"Decrypted message using key with kid={key.kid}") return msg raise DecryptionFailed("No available key that could decrypt the message") diff --git a/src/cryptojwt/jwe/jwe_ec.py b/src/cryptojwt/jwe/jwe_ec.py index bcf81059..db2d2e18 100644 --- a/src/cryptojwt/jwe/jwe_ec.py +++ b/src/cryptojwt/jwe/jwe_ec.py @@ -1,3 +1,4 @@ +import contextlib import struct from cryptography.hazmat.primitives.asymmetric import ec @@ -110,8 +111,8 @@ def enc_setup(self, msg, key=None, auth_data=b"", **kwargs): if self.alg == "ECDH-ES": try: dk_len = KEY_LEN[self.enc] - except KeyError: - raise ValueError("Unknown key length for algorithm %s" % self.enc) + except KeyError as exc: + raise ValueError(f"Unknown key length for algorithm {self.enc}") from exc cek = ecdh_derive_key(_epk, key.pub_key, apu, apv, str(self.enc).encode(), dk_len) elif self.alg in ["ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]: @@ -121,7 +122,7 @@ def enc_setup(self, msg, key=None, auth_data=b"", **kwargs): cek = self._generate_key(self.enc, cek=cek) encrypted_key = aes_key_wrap(kek, cek) else: - raise Exception("Unsupported algorithm %s" % self.alg) + raise Exception(f"Unsupported algorithm {self.alg}") return cek, encrypted_key, iv, params, epk @@ -152,8 +153,8 @@ def dec_setup(self, token, key=None, **kwargs): if self.headers["alg"] == "ECDH-ES": try: dk_len = KEY_LEN[self.headers["enc"]] - except KeyError: - raise Exception("Unknown key length for algorithm") + except KeyError as exc: + raise Exception("Unknown key length for algorithm") from exc self.cek = ecdh_derive_key( key, @@ -173,7 +174,7 @@ def dec_setup(self, token, key=None, **kwargs): kek = ecdh_derive_key(key, epubkey.pub_key, apu, apv, str(_post).encode(), klen) self.cek = aes_key_unwrap(kek, token.encrypted_key()) else: - raise Exception("Unsupported algorithm %s" % self.headers["alg"]) + raise Exception("Unsupported algorithm {}".format(self.headers["alg"])) return self.cek @@ -190,10 +191,8 @@ def encrypt(self, key=None, iv="", cek="", **kwargs): _msg = as_bytes(self.msg) _args = self._dict - try: + with contextlib.suppress(KeyError): _args["kid"] = kwargs["kid"] - except KeyError: - pass if "params" in kwargs: if "apu" in kwargs["params"]: @@ -204,7 +203,7 @@ def encrypt(self, key=None, iv="", cek="", **kwargs): _args["epk"] = kwargs["params"]["epk"] jwe = JWEnc(**_args) - ctxt, tag, cek = super(JWE_EC, self).enc_setup( + ctxt, tag, cek = super().enc_setup( self["enc"], _msg, auth_data=jwe.b64_encode_header(), key=cek, iv=iv ) if "encrypted_key" in kwargs: @@ -212,15 +211,12 @@ def encrypt(self, key=None, iv="", cek="", **kwargs): return jwe.pack(parts=[iv, ctxt, tag]) def decrypt(self, token=None, **kwargs): - if isinstance(token, JWEnc): - jwe = token - else: - jwe = JWEnc().unpack(token) + jwe = token if isinstance(token, JWEnc) else JWEnc().unpack(token) if not self.cek: raise Exception("Content Encryption Key is Not Yet Set") - msg = super(JWE_EC, self)._decrypt( + msg = super()._decrypt( self.headers["enc"], self.cek, self.ctxt, diff --git a/src/cryptojwt/jwe/jwe_hmac.py b/src/cryptojwt/jwe/jwe_hmac.py index ae5d010b..3a1d1ed3 100644 --- a/src/cryptojwt/jwe/jwe_hmac.py +++ b/src/cryptojwt/jwe/jwe_hmac.py @@ -1,3 +1,4 @@ +import contextlib import logging import zlib @@ -34,10 +35,8 @@ def encrypt(self, key, iv="", cek="", **kwargs): _msg = as_bytes(self.msg) _args = self._dict - try: + with contextlib.suppress(KeyError): _args["kid"] = kwargs["kid"] - except KeyError: - pass jwe = JWEnc(**_args) @@ -68,10 +67,7 @@ def decrypt(self, token, key=None, cek=None): if not key and not cek: raise MissingKey("On of key or cek must be specified") - if isinstance(token, JWEnc): - jwe = token - else: - jwe = JWEnc().unpack(token) + jwe = token if isinstance(token, JWEnc) else JWEnc().unpack(token) if len(jwe) != 5: raise WrongNumberOfParts(len(jwe)) diff --git a/src/cryptojwt/jwe/jwe_rsa.py b/src/cryptojwt/jwe/jwe_rsa.py index f34b1331..51965d3d 100644 --- a/src/cryptojwt/jwe/jwe_rsa.py +++ b/src/cryptojwt/jwe/jwe_rsa.py @@ -49,7 +49,7 @@ def encrypt(self, key, iv="", cek="", **kwargs): if self["zip"] == "DEF": _msg = zlib.compress(_msg) else: - raise ParameterError("Zip has unknown value: %s" % self["zip"]) + raise ParameterError("Zip has unknown value: {}".format(self["zip"])) kwarg_cek = cek or None @@ -58,7 +58,7 @@ def encrypt(self, key, iv="", cek="", **kwargs): cek = self._generate_key(_enc, cek) self["cek"] = cek - logger.debug("cek: %s, iv: %s" % ([c for c in cek], [c for c in iv])) + logger.debug(f"cek: {[c for c in cek]}, iv: {[c for c in iv]}") _encrypt = RSAEncrypter(self.with_digest).encrypt @@ -92,10 +92,7 @@ def decrypt(self, token, key, cek=None): :param cek: Ephemeral cipher key :return: The decrypted message """ - if not isinstance(token, JWEnc): - jwe = JWEnc().unpack(token) - else: - jwe = token + jwe = JWEnc().unpack(token) if not isinstance(token, JWEnc) else token self.jwt = jwe.encrypted_key() jek = jwe.encrypted_key() diff --git a/src/cryptojwt/jwe/jwekey.py b/src/cryptojwt/jwe/jwekey.py index 31a1c8ae..719e931f 100644 --- a/src/cryptojwt/jwe/jwekey.py +++ b/src/cryptojwt/jwe/jwekey.py @@ -29,8 +29,8 @@ def _generate_key(encalg, cek=""): except KeyError: try: _key = get_random_bytes(KEY_LEN_BYTES[encalg]) - except KeyError: - raise ValueError("Unsupported encryption algorithm %s" % encalg) + except KeyError as exc: + raise ValueError(f"Unsupported encryption algorithm {encalg}") from exc return _key @@ -77,7 +77,7 @@ def _decrypt(enc, key, ctxt, iv, tag, auth_data=b""): elif enc in ["A128CBC-HS256", "A192CBC-HS384", "A256CBC-HS512"]: aes = AES_CBCEncrypter(key=key) else: - raise Exception("Unsupported encryption algorithm %s" % enc) + raise Exception(f"Unsupported encryption algorithm {enc}") try: return aes.decrypt(ctxt, iv=iv, auth_data=auth_data, tag=tag) diff --git a/src/cryptojwt/jwe/jwenc.py b/src/cryptojwt/jwe/jwenc.py index 232ff0b6..61ef6439 100644 --- a/src/cryptojwt/jwe/jwenc.py +++ b/src/cryptojwt/jwe/jwenc.py @@ -48,7 +48,7 @@ def is_jwe(self): if "alg" in self.headers and "enc" in self.headers: for typ in ["alg", "enc"]: if self.headers[typ] not in SUPPORTED[typ]: - logger.debug("Not supported %s algorithm: %s" % (typ, self.headers[typ])) + logger.debug(f"Not supported {typ} algorithm: {self.headers[typ]}") return False else: return False diff --git a/src/cryptojwt/jwe/utils.py b/src/cryptojwt/jwe/utils.py index 2ee69b98..f1c44aa0 100644 --- a/src/cryptojwt/jwe/utils.py +++ b/src/cryptojwt/jwe/utils.py @@ -7,8 +7,6 @@ from cryptography.hazmat.primitives.hashes import SHA384 from cryptography.hazmat.primitives.hashes import SHA512 -from ..utils import b64e - LENMET = {32: (16, SHA256), 48: (24, SHA384), 64: (32, SHA512)} @@ -20,8 +18,8 @@ def get_keys_seclen_dgst(key, iv): # Select the digest to use based on key length try: seclen, hash_method = LENMET[len(key)] - except KeyError: - raise Exception("Invalid CBC+HMAC key length: %s bytes" % len(key)) + except KeyError as exc: + raise Exception(f"Invalid CBC+HMAC key length: {len(key)} bytes") from exc # Split the key ka = key[:seclen] diff --git a/src/cryptojwt/jwk/__init__.py b/src/cryptojwt/jwk/__init__.py index e17f5a1a..070b56e4 100644 --- a/src/cryptojwt/jwk/__init__.py +++ b/src/cryptojwt/jwk/__init__.py @@ -14,7 +14,7 @@ USE = {"sign": "sig", "decrypt": "enc", "encrypt": "enc", "verify": "sig"} -class JWK(object): +class JWK: """ Basic JSON Web key class. Jason Web keys are described in RFC 7517 (https://tools.ietf.org/html/rfc7517). @@ -56,7 +56,7 @@ def __init__( "ECDH-ES+A192KW", "ECDH-ES+A256KW", ]: - raise UnsupportedAlgorithm("Unknown algorithm: {}".format(alg)) + raise UnsupportedAlgorithm(f"Unknown algorithm: {alg}") elif use == "sig": # The list comes from https://tools.ietf.org/html/rfc7518#page-6 # Should map against SIGNER_ALGS in cryptojwt.jws.jws @@ -79,7 +79,7 @@ def __init__( "Ed448", "none", ]: - raise UnsupportedAlgorithm("Unknown algorithm: {}".format(alg)) + raise UnsupportedAlgorithm(f"Unknown algorithm: {alg}") else: # potentially used both for encryption and signing if alg not in [ "HS256", @@ -110,7 +110,7 @@ def __init__( "ECDH-ES+A192KW", "ECDH-ES+A256KW", ]: - raise UnsupportedAlgorithm("Unknown algorithm: {}".format(alg)) + raise UnsupportedAlgorithm(f"Unknown algorithm: {alg}") self.alg = alg if isinstance(use, str): @@ -271,7 +271,7 @@ def thumbprint(self, hash_function, members=None): else: if isinstance(_val, bytes): _val = as_unicode(_val) - _se.append('"{}":{}'.format(elem, json.dumps(_val))) + _se.append(f'"{elem}":{json.dumps(_val)}') _json = "{{{}}}".format(",".join(_se)) return b64e(DIGEST_HASH[hash_function](_json)) @@ -298,7 +298,7 @@ def update(self): pass def key_len(self): - raise NotImplemented + raise NotImplementedError def pems_to_x5c(cert_chain): @@ -360,7 +360,7 @@ def certificate_fingerprint(der, hash="sha256"): def pem_hash(pem_file): - with open(pem_file, "r") as fp: + with open(pem_file) as fp: pem = fp.read() der = ssl.PEM_cert_to_DER_cert(pem) diff --git a/src/cryptojwt/jwk/asym.py b/src/cryptojwt/jwk/asym.py index 5ce6af94..1930248c 100644 --- a/src/cryptojwt/jwk/asym.py +++ b/src/cryptojwt/jwk/asym.py @@ -38,8 +38,8 @@ def appropriate_for(self, usage, **kwargs): """ try: _use = USE[usage] - except KeyError: - raise ValueError("Unknown key usage") + except KeyError as exc: + raise ValueError("Unknown key usage") from exc else: if usage in ["sign", "decrypt"]: if not self.use or _use == self.use: @@ -58,10 +58,7 @@ def has_private_key(self): :return: True/False """ - if self.priv_key: - return True - else: - return False + return bool(self.priv_key) def public_key(self): """ diff --git a/src/cryptojwt/jwk/ec.py b/src/cryptojwt/jwk/ec.py index 0a6757a7..eedff4a5 100644 --- a/src/cryptojwt/jwk/ec.py +++ b/src/cryptojwt/jwk/ec.py @@ -45,8 +45,8 @@ def ec_construct_public(num): """ try: _sec_crv = NIST2SEC[as_unicode(num["crv"])] - except KeyError: - raise UnsupportedECurve("Unsupported elliptic curve: {}".format(num["crv"])) + except KeyError as exc: + raise UnsupportedECurve("Unsupported elliptic curve: {}".format(num["crv"])) from exc ecpn = ec.EllipticCurvePublicNumbers(num["x"], num["y"], _sec_crv()) return ecpn.public_key() @@ -152,8 +152,8 @@ def deserialize(self): {"x": _x, "y": _y, "crv": self.crv, "d": _d} ) self.pub_key = self.priv_key.public_key() - except ValueError as err: - raise DeSerializationNotPossible(str(err)) + except ValueError as exc: + raise DeSerializationNotPossible(str(exc)) from exc else: self.pub_key = ec_construct_public({"x": _x, "y": _y, "crv": self.crv}) @@ -249,10 +249,8 @@ def __eq__(self, other): if other.private_key(): if cmp_keys(self.priv_key, other.priv_key, ec.EllipticCurvePrivateKey): return True - elif self.private_key(): - return False else: - return True + return not self.private_key() return False @@ -339,5 +337,5 @@ def import_ec_key(pem_data): def import_ec_key_from_cert_file(pem_file): - with open(pem_file, "r") as cert_file: + with open(pem_file) as cert_file: return import_ec_key(cert_file.read()) diff --git a/src/cryptojwt/jwk/hmac.py b/src/cryptojwt/jwk/hmac.py index f86c366d..dc166cb5 100644 --- a/src/cryptojwt/jwk/hmac.py +++ b/src/cryptojwt/jwk/hmac.py @@ -81,8 +81,8 @@ def appropriate_for(self, usage, alg="HS256"): """ try: _use = USE[usage] - except: - raise ValueError("Unknown key usage") + except Exception as exc: + raise ValueError("Unknown key usage") from exc else: if not self.use or self.use == _use: if _use == "sig": @@ -90,7 +90,7 @@ def appropriate_for(self, usage, alg="HS256"): else: return self.encryption_key(alg) - raise WrongUsage("This key can't be used for {}".format(usage)) + raise WrongUsage(f"This key can't be used for {usage}") def encryption_key(self, alg, **kwargs): """ @@ -106,8 +106,8 @@ def encryption_key(self, alg, **kwargs): try: tsize = ALG2KEYLEN[alg] - except KeyError: - raise UnsupportedAlgorithm(alg) + except KeyError as exc: + raise UnsupportedAlgorithm(alg) from exc if tsize <= 32: # SHA256 @@ -121,7 +121,7 @@ def encryption_key(self, alg, **kwargs): else: raise JWKException("No support for symmetric keys > 512 bits") - logger.debug("Symmetric encryption key: {}".format(as_unicode(b64e(_enc_key)))) + logger.debug(f"Symmetric encryption key: {as_unicode(b64e(_enc_key))}") return _enc_key diff --git a/src/cryptojwt/jwk/jwk.py b/src/cryptojwt/jwk/jwk.py index 0e10a95e..47bba714 100644 --- a/src/cryptojwt/jwk/jwk.py +++ b/src/cryptojwt/jwk/jwk.py @@ -74,7 +74,7 @@ def ensure_params(kty, provided, required): """Ensure all required parameters are present in dictionary""" if not required <= provided: missing = required - provided - raise MissingValue("Missing properties for kty={}, {}".format(kty, str(list(missing)))) + raise MissingValue(f"Missing properties for kty={kty}, {str(list(missing))}") def key_from_jwk_dict(jwk_dict, private=None): @@ -100,7 +100,7 @@ def key_from_jwk_dict(jwk_dict, private=None): if _jwk_dict["crv"] in NIST2SEC: curve = NIST2SEC[_jwk_dict["crv"]]() else: - raise UnsupportedAlgorithm("Unknown curve: %s" % (_jwk_dict["crv"])) + raise UnsupportedAlgorithm("Unknown curve: {}".format(_jwk_dict["crv"])) if _jwk_dict.get("d", None) is not None: # Ecdsa private key. @@ -183,7 +183,7 @@ def jwk_wrap(key, use="", kid=""): :param kid: A key id :return: The Key instance """ - if isinstance(key, rsa.RSAPublicKey) or isinstance(key, rsa.RSAPrivateKey): + if isinstance(key, (rsa.RSAPublicKey, rsa.RSAPrivateKey)): kspec = RSAKey(use=use, kid=kid).load_key(key) elif isinstance(key, str): kspec = SYMKey(key=key, use=use, kid=kid) diff --git a/src/cryptojwt/jwk/okp.py b/src/cryptojwt/jwk/okp.py index 83159629..f58f71a2 100644 --- a/src/cryptojwt/jwk/okp.py +++ b/src/cryptojwt/jwk/okp.py @@ -145,16 +145,16 @@ def deserialize(self): if isinstance(self.d, (str, bytes)): try: self.priv_key = OKP_CRV2PRIVATE[self.crv].from_private_bytes(deser(self.d)) - except KeyError: - raise UnsupportedOKPCurve("Unsupported OKP curve: {}".format(self.crv)) + except KeyError as exc: + raise UnsupportedOKPCurve(f"Unsupported OKP curve: {self.crv}") from exc self.pub_key = self.priv_key.public_key() - except ValueError as err: - raise DeSerializationNotPossible(str(err)) + except ValueError as exc: + raise DeSerializationNotPossible(str(exc)) from exc else: try: self.pub_key = OKP_CRV2PUBLIC[self.crv].from_public_bytes(_x) - except KeyError: - raise UnsupportedOKPCurve("Unsupported OKP curve: {}".format(self.crv)) + except KeyError as exc: + raise UnsupportedOKPCurve(f"Unsupported OKP curve: {self.crv}") from exc def _serialize_public(self, key): self.x = b64e( @@ -278,10 +278,8 @@ def __eq__(self, other): if other.private_key(): if cmp_keys(self.priv_key, other.priv_key, _private_cls): return True - elif self.private_key(): - return False else: - return True + return not self.private_key() return False @@ -376,5 +374,5 @@ def import_okp_key(pem_data): def import_okp_key_from_cert_file(pem_file): - with open(pem_file, "r") as cert_file: + with open(pem_file) as cert_file: return import_okp_key(cert_file.read()) diff --git a/src/cryptojwt/jwk/rsa.py b/src/cryptojwt/jwk/rsa.py index e4290e93..14f88c73 100644 --- a/src/cryptojwt/jwk/rsa.py +++ b/src/cryptojwt/jwk/rsa.py @@ -105,7 +105,7 @@ def import_rsa_key(pem_data): def import_rsa_key_from_cert_file(pem_file): - with open(pem_file, "r") as cert_file: + with open(pem_file) as cert_file: return import_rsa_key(cert_file.read()) @@ -117,13 +117,7 @@ def rsa_eq(key1, key2): :param key2: :return: """ - pn1 = key1.public_numbers() - pn2 = key2.public_numbers() - # Check if two RSA keys are in fact the same - if pn1 == pn2: - return True - else: - return False + return key1.public_numbers() == key2.public_numbers() def x509_rsa_load(txt): @@ -208,10 +202,7 @@ def cmp_private_numbers(pn1, pn2): if not cmp_public_numbers(pn1.public_numbers, pn2.public_numbers): return False - for param in ["d", "p", "q"]: - if getattr(pn1, param) != getattr(pn2, param): - return False - return True + return all(getattr(pn1, param) == getattr(pn2, param) for param in ["d", "p", "q"]) class RSAKey(AsymmetricKey): @@ -284,9 +275,7 @@ def __init__( self.pub_key = self.priv_key.public_key() elif self.pub_key: self._serialize(self.pub_key) - elif has_public_key_parts: - self.deserialize() - elif has_x509_cert_chain: + elif has_public_key_parts or has_x509_cert_chain: self.deserialize() elif not self.n and not self.e: pass @@ -324,8 +313,8 @@ def deserialize(self): self.pub_key = self.priv_key.public_key() else: self.pub_key = rsa_construct_public(numbers) - except ValueError as err: - raise DeSerializationNotPossible("%s" % err) + except ValueError as exc: + raise DeSerializationNotPossible(str(exc)) from exc if self.x5c: _cert_chain = [] diff --git a/src/cryptojwt/jwk/wrap.py b/src/cryptojwt/jwk/wrap.py index 95aa7e8a..2cbdac58 100644 --- a/src/cryptojwt/jwk/wrap.py +++ b/src/cryptojwt/jwk/wrap.py @@ -19,8 +19,8 @@ def wrap_key(key: JWK, wrapping_key: JWK, wrap_params: dict = DEFAULT_WRAP_PARAM message = json.dumps(key.serialize(private=True)).encode() try: enc_params = wrap_params[wrapping_key.kty] - except KeyError: - raise ValueError("Unsupported wrapping key type") + except KeyError as exc: + raise ValueError("Unsupported wrapping key type") from exc _jwe = JWE(msg=message, **enc_params) return _jwe.encrypt(keys=[wrapping_key], kid=wrapping_key.kid) diff --git a/src/cryptojwt/jwk/x509.py b/src/cryptojwt/jwk/x509.py index d1fd885c..c08bd63b 100644 --- a/src/cryptojwt/jwk/x509.py +++ b/src/cryptojwt/jwk/x509.py @@ -50,7 +50,7 @@ def import_public_key_from_pem_data(pem_data): :return: rsa.RSAPublicKey instance """ if not pem_data.startswith(PREFIX): - pem_data = bytes("{}\n{}\n{}".format(PREFIX, pem_data, POSTFIX), "utf-8") + pem_data = bytes(f"{PREFIX}\n{pem_data}\n{POSTFIX}", "utf-8") else: pem_data = bytes(pem_data, "utf-8") cert = x509.load_pem_x509_certificate(pem_data) @@ -106,9 +106,9 @@ def load_x509_cert(url, httpc, spec2key, **get_args): elif isinstance(public_key, ec.EllipticCurvePublicKey): return {"ec": public_key} else: - raise Exception("HTTP Get error: %s" % r.status_code) + raise Exception(f"HTTP Get error: {r.status_code}") except Exception as err: # not a RSA key - logger.warning("Can't load key: %s" % err) + logger.warning(f"Can't load key: {err}") return [] diff --git a/src/cryptojwt/jws/__init__.py b/src/cryptojwt/jws/__init__.py index d77eb749..b3482b3c 100644 --- a/src/cryptojwt/jws/__init__.py +++ b/src/cryptojwt/jws/__init__.py @@ -1,4 +1,4 @@ -class Signer(object): +class Signer: """Abstract base class for signing algorithms.""" def sign(self, msg, key): diff --git a/src/cryptojwt/jws/dsa.py b/src/cryptojwt/jws/dsa.py index 6ddedcfd..536d14f6 100644 --- a/src/cryptojwt/jws/dsa.py +++ b/src/cryptojwt/jws/dsa.py @@ -1,5 +1,3 @@ -import sys - from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import ec @@ -26,7 +24,7 @@ def __init__(self, algorithm="ES256"): self.hash_algorithm = hashes.SHA512 self.curve_name = "secp521r1" else: - raise Unsupported("algorithm: {}".format(algorithm)) + raise Unsupported(f"algorithm: {algorithm}") self.algorithm = algorithm @@ -77,8 +75,8 @@ def verify(self, msg, sig, key): (r, s) = self._split_raw_signature(sig) asn1sig = encode_dss_signature(r, s) key.verify(asn1sig, msg, ec.ECDSA(self.hash_algorithm())) - except InvalidSignature as err: - raise BadSignature(err) + except InvalidSignature as exc: + raise BadSignature(exc) from exc else: return True @@ -92,8 +90,8 @@ def _cross_check(self, pub_key): """ if self.curve_name != pub_key.curve.name: raise ValueError( - "The curve in private key {} and in algorithm {} don't " - "match".format(pub_key.curve.name, self.curve_name) + f"The curve in private key {pub_key.curve.name} and in algorithm {self.curve_name} don't " + "match" ) @staticmethod diff --git a/src/cryptojwt/jws/eddsa.py b/src/cryptojwt/jws/eddsa.py index 6a88f8e0..cf88ba50 100644 --- a/src/cryptojwt/jws/eddsa.py +++ b/src/cryptojwt/jws/eddsa.py @@ -1,11 +1,8 @@ -import sys - from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric import ed448 from cryptography.hazmat.primitives.asymmetric import ed25519 from ..exception import BadSignature -from ..exception import Unsupported from . import Signer @@ -65,7 +62,7 @@ def verify(self, msg, sig, key): try: key.verify(sig, msg) - except InvalidSignature as err: - raise BadSignature(err) + except InvalidSignature as exc: + raise BadSignature(exc) from exc else: return True diff --git a/src/cryptojwt/jws/hmac.py b/src/cryptojwt/jws/hmac.py index e084e224..775dbd2f 100644 --- a/src/cryptojwt/jws/hmac.py +++ b/src/cryptojwt/jws/hmac.py @@ -14,7 +14,7 @@ def __init__(self, algorithm="SHA256"): elif algorithm == "SHA512": self.algorithm = hashes.SHA512 else: - raise Unsupported("algorithm: {}".format(algorithm)) + raise Unsupported(f"algorithm: {algorithm}") def sign(self, msg, key): """ @@ -44,5 +44,5 @@ def verify(self, msg, sig, key): h.update(msg) h.verify(sig) return True - except: + except Exception: return False diff --git a/src/cryptojwt/jws/jws.py b/src/cryptojwt/jws/jws.py index 9306a297..686b5f80 100644 --- a/src/cryptojwt/jws/jws.py +++ b/src/cryptojwt/jws/jws.py @@ -24,12 +24,6 @@ from .rsa import RSASigner from .utils import alg2keytype -try: - from builtins import object - from builtins import str -except ImportError: - pass - logger = logging.getLogger(__name__) KDESC = ["use", "kid", "kty"] @@ -99,10 +93,10 @@ def alg_keys(self, keys, use, protected=None): else: if "kid" in self: raise NoSuitableSigningKeys( - "No key for algorithm: %s and kid: %s" % (_alg, self["kid"]) + "No key for algorithm: {} and kid: {}".format(_alg, self["kid"]) ) else: - raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg) + raise NoSuitableSigningKeys(f"No key for algorithm: {_alg}") return key, xargs, _alg @@ -133,8 +127,8 @@ def sign_compact(self, keys=None, protected=None, **kwargs): # All other cases try: _signer = SIGNER_ALGS[_alg] - except KeyError: - raise UnknownAlgorithm(_alg) + except KeyError as exc: + raise UnknownAlgorithm(_alg) from exc _input = jwt.pack(parts=[self.msg]) @@ -143,7 +137,7 @@ def sign_compact(self, keys=None, protected=None, **kwargs): else: sig = _signer.sign(_input.encode("utf-8"), key.key) - logger.debug("Signed message using key with kid=%s" % key.kid) + logger.debug(f"Signed message using key with kid={key.kid}") return ".".join([_input, b64encode_item(sig).decode("utf-8")]) def verify_compact(self, jws=None, keys=None, allow_none=False, sigalg=None): @@ -207,30 +201,24 @@ def verify_compact_verbose(self, jws=None, keys=None, allow_none=False, sigalg=N ) if sigalg and sigalg != _alg: - raise SignerAlgError("Expected {0} got {1}".format(sigalg, jwt.headers["alg"])) + raise SignerAlgError("Expected {} got {}".format(sigalg, jwt.headers["alg"])) self["alg"] = _alg - if keys: - _keys = self.pick_keys(keys) - else: - _keys = self.pick_keys(self._get_keys()) + _keys = self.pick_keys(keys) if keys else self.pick_keys(self._get_keys()) if not _keys: if "kid" in self: - raise NoSuitableSigningKeys("No key with kid: %s" % (self["kid"])) + raise NoSuitableSigningKeys("No key with kid: {}".format(self["kid"])) elif "kid" in self.jwt.headers: - raise NoSuitableSigningKeys("No key with kid: %s" % (self.jwt.headers["kid"])) + raise NoSuitableSigningKeys("No key with kid: {}".format(self.jwt.headers["kid"])) else: - raise NoSuitableSigningKeys("No key for algorithm: %s" % _alg) + raise NoSuitableSigningKeys(f"No key for algorithm: {_alg}") verifier = SIGNER_ALGS[_alg] for key in _keys: - if isinstance(key, AsymmetricKey): - _key = key.public_key() - else: - _key = key.key + _key = key.public_key() if isinstance(key, AsymmetricKey) else key.key try: if not verifier.verify(jwt.sign_input(), jwt.signature(), _key): @@ -238,9 +226,9 @@ def verify_compact_verbose(self, jws=None, keys=None, allow_none=False, sigalg=N except (BadSignature, IndexError): pass except (ValueError, TypeError) as err: - logger.warning('Exception "{}" caught'.format(err)) + logger.warning(f'Exception "{err}" caught') else: - logger.debug("Verified message using key with kid=%s" % key.kid) + logger.debug(f"Verified message using key with kid={key.kid}") self.msg = jwt.payload() self.key = key self._protected_headers = jwt.headers.copy() @@ -310,8 +298,8 @@ def verify_json(self, jws, keys=None, allow_none=False, at_least_one=False): try: _payload = _jwss["payload"] - except KeyError: - raise FormatError("Missing payload") + except KeyError as exc: + raise FormatError("Missing payload") from exc try: _signs = _jwss["signatures"] @@ -347,13 +335,11 @@ def verify_json(self, jws, keys=None, allow_none=False, at_least_one=False): _tmp = self.verify_compact(token, keys, allow_none) except NoSuitableSigningKeys: if at_least_one is True: - logger.warning( - "Could not verify signature with headers: {}".format(all_headers) - ) + logger.warning(f"Could not verify signature with headers: {all_headers}") continue else: raise - except JWSException as err: + except JWSException: raise if _claim is None: @@ -412,7 +398,7 @@ def _is_compact_jws(self, jws): try: jwt = JWSig().unpack(jws) except Exception as err: - logger.warning("Could not parse JWS: {}".format(err)) + logger.warning(f"Could not parse JWS: {err}") return False if "alg" not in jwt.headers: @@ -422,7 +408,7 @@ def _is_compact_jws(self, jws): jwt.headers["alg"] = "none" if jwt.headers["alg"] not in SIGNER_ALGS: - logger.debug("UnknownSignerAlg: %s" % jwt.headers["alg"]) + logger.debug("UnknownSignerAlg: {}".format(jwt.headers["alg"])) return False self.jwt = jwt diff --git a/src/cryptojwt/jws/pss.py b/src/cryptojwt/jws/pss.py index 3d3da8d5..65362a47 100644 --- a/src/cryptojwt/jws/pss.py +++ b/src/cryptojwt/jws/pss.py @@ -21,7 +21,7 @@ def __init__(self, algorithm="SHA256"): elif algorithm == "SHA512": self.hash_algorithm = hashes.SHA512 else: - raise Unsupported("algorithm: {}".format(algorithm)) + raise Unsupported(f"algorithm: {algorithm}") def sign(self, msg, key): """ @@ -64,7 +64,7 @@ def verify(self, msg, signature, key): ), self.hash_algorithm(), ) - except InvalidSignature as err: - raise BadSignature(err) + except InvalidSignature as exc: + raise BadSignature(exc) from exc else: return True diff --git a/src/cryptojwt/jws/rsa.py b/src/cryptojwt/jws/rsa.py index 566c7538..414ec14c 100644 --- a/src/cryptojwt/jws/rsa.py +++ b/src/cryptojwt/jws/rsa.py @@ -42,8 +42,8 @@ def verify(self, msg, signature, key): raise TypeError("The public key must be an instance of RSAPublicKey") try: key.verify(signature, msg, self.padding, self.hash) - except InvalidSignature as err: - raise BadSignature(str(err)) + except InvalidSignature as exc: + raise BadSignature(str(exc)) from exc except AttributeError: return False else: diff --git a/src/cryptojwt/jws/utils.py b/src/cryptojwt/jws/utils.py index 40ac46f2..c0252825 100644 --- a/src/cryptojwt/jws/utils.py +++ b/src/cryptojwt/jws/utils.py @@ -47,9 +47,7 @@ def alg2keytype(alg): return "RSA" elif alg.startswith("HS") or alg.startswith("A"): return "oct" - elif alg == "Ed25519": - return "OKP" - elif alg == "Ed448": + elif alg == "Ed25519" or alg == "Ed448": return "OKP" elif alg.startswith("ES") or alg.startswith("ECDH-ES"): return "EC" @@ -91,4 +89,4 @@ def parse_rsa_algorithm(algorithm): padding.PSS(mgf=padding.MGF1(hashes.SHA512()), salt_length=padding.PSS.MAX_LENGTH), ) else: - raise UnsupportedAlgorithm("Unknown algorithm: {}".format(algorithm)) + raise UnsupportedAlgorithm(f"Unknown algorithm: {algorithm}") diff --git a/src/cryptojwt/jwt.py b/src/cryptojwt/jwt.py index 37ee035d..75cc0d9b 100755 --- a/src/cryptojwt/jwt.py +++ b/src/cryptojwt/jwt.py @@ -1,5 +1,6 @@ """Basic JSON Web Token implementation.""" +import contextlib import json import logging import time @@ -50,10 +51,7 @@ def pick_key(keys, use, alg="", key_type="", kid=""): """ res = [] if not key_type: - if use == "sig": - key_type = jws_alg2keytype(alg) - else: - key_type = jwe_alg2keytype(alg) + key_type = jws_alg2keytype(alg) if use == "sig" else jwe_alg2keytype(alg) for key in keys: if key.use and key.use != use: @@ -67,7 +65,7 @@ def pick_key(keys, use, alg="", key_type="", kid=""): if key.alg == "" and alg: if key_type == "EC": - if key.crv != "P-{}".format(alg[2:]): + if key.crv != f"P-{alg[2:]}": continue elif alg and key.alg != alg: continue @@ -143,10 +141,8 @@ def receivers(self): def my_keys(self, issuer_id="", use="sig"): _k = self.key_jar.get(use, issuer_id=issuer_id) if issuer_id != "": - try: + with contextlib.suppress(KeyError): _k.extend(self.key_jar.get(use, issuer_id="")) - except KeyError: - pass return _k def _encrypt(self, payload, recv, cty="JWT", zip=""): @@ -209,7 +205,7 @@ def pack_key(self, issuer_id="", kid=""): keys = pick_key(self.my_keys(issuer_id, "sig"), "sig", alg=self.alg, kid=kid) if not keys: - raise NoSuitableSigningKeys("kid={}".format(kid)) + raise NoSuitableSigningKeys(f"kid={kid}") return keys[0] # Might be more then one if kid == '' @@ -225,7 +221,7 @@ def pack( aud: Optional[str] = None, iat: Optional[int] = None, jws_headers: Optional[Dict[str, str]] = None, - **kwargs + **kwargs, ) -> str: """ @@ -264,11 +260,7 @@ def pack( issuer_id = self.iss if self.sign: - if self.alg != "none": - _key = self.pack_key(issuer_id, kid) - # _args['kid'] = _key.kid - else: - _key = None + _key = self.pack_key(issuer_id, kid) if self.alg != "none" else None jws_headers = jws_headers or {} @@ -442,9 +434,7 @@ def remove_jwt_parameters(arg): """ for param in JWT.jwt_parameters: - try: + with contextlib.suppress(KeyError): del arg[param] - except KeyError: - pass return arg diff --git a/src/cryptojwt/jwx.py b/src/cryptojwt/jwx.py index 7d4239d9..ac214fb5 100644 --- a/src/cryptojwt/jwx.py +++ b/src/cryptojwt/jwx.py @@ -89,13 +89,13 @@ def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs): try: _spec = load_x509_cert(self["x5u"], self.httpc, {}) self._jwk = RSAKey(pub_key=_spec["rsa"]).to_dict() - except Exception: + except Exception as exc: # ca_chain = load_x509_cert_chain(self["x5u"]) - raise ValueError("x5u") + raise ValueError("x5u") from exc else: self._dict[key] = _val if key in DEPRECATED and _val in DEPRECATED[key]: - warnings.warn(f"{key}={_val} deprecated") + warnings.warn(f"{key}={_val} deprecated", stacklevel=1) def _set_jwk(self, val): if isinstance(val, dict): @@ -123,8 +123,8 @@ def __setitem__(self, key, value): def __getattr__(self, item): try: return self._dict[item] - except KeyError: - raise AttributeError(item) + except KeyError as exc: + raise AttributeError(item) from exc def keys(self): """Return all keys.""" diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 0708d0ec..c06e94e2 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -1,5 +1,6 @@ """Implementation of a Key Bundle.""" +import contextlib import copy import json import logging @@ -276,13 +277,7 @@ def __init__( if keys: self.source = None - if isinstance(keys, dict): - if "keys" in keys: - initial_keys = keys["keys"] - else: - initial_keys = [keys] - else: - initial_keys = keys + initial_keys = keys.get("keys", [keys]) if isinstance(keys, dict) else keys self._keys = self.jwk_dicts_as_keys(initial_keys) else: self._set_source(source, fileformat) @@ -372,10 +367,10 @@ def jwk_dicts_as_keys(self, keys): for _use in _usage: try: _key = K2C[_typ](use=_use, **inst) - except KeyError: + except KeyError as exc: if not self.ignore_invalid_keys: - raise UnknownKeyType(inst) - _error = "UnknownKeyType: {}".format(_typ) + raise UnknownKeyType(inst) from exc + _error = f"UnknownKeyType: {_typ}" continue except (UnsupportedECurve, UnsupportedAlgorithm) as err: if not self.ignore_invalid_keys: @@ -441,7 +436,7 @@ def _do_local_der(self, filename, keytype, keyusage=None, kid=""): key_args["priv_key"] = _key key_args["pub_key"] = _key.public_key() else: - raise NotImplementedError("No support for DER decoding of key type {}".format(_kty)) + raise NotImplementedError(f"No support for DER decoding of key type {_kty}") if not keyusage: key_args["use"] = ["enc", "sig"] @@ -484,7 +479,7 @@ def _do_remote(self, set_keys=True): _http_resp = self.httpc("GET", self.source, **httpc_params) except Exception as err: LOGGER.error(err) - raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err))) + raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err))) from err new_keys = None load_successful = _http_resp.status_code == 200 @@ -500,13 +495,13 @@ def _do_remote(self, set_keys=True): LOGGER.debug("Loaded JWKS: %s from %s", _http_resp.text, self.source) try: new_keys = self.jwk_dicts_as_keys(self.imp_jwks["keys"]) - except KeyError: + except KeyError as exc: LOGGER.error("No 'keys' keyword in JWKS") self.ignore_errors_until = time.time() + self.ignore_errors_period - raise UpdateFailed(MALFORMED.format(self.source)) + raise UpdateFailed(MALFORMED.format(self.source)) from exc if hasattr(_http_resp, "headers"): - headers = getattr(_http_resp, "headers") + headers = _http_resp.headers self.last_remote = headers.get("last-modified") or headers.get("date") elif not_modified: LOGGER.debug("%s not modified since %s", self.source, self.last_remote) @@ -646,7 +641,7 @@ def remove_keys_by_type(self, typ): :param typ: Type of key (rsa, ec, oct, ..) """ _typs = [typ.lower(), typ.upper()] - self._keys = [k for k in self._keys if not k.kty in _typs] + self._keys = [k for k in self._keys if k.kty not in _typs] def __str__(self): return str(self.jwks()) @@ -690,10 +685,8 @@ def remove(self, key): :param key: The key that should be removed """ - try: + with contextlib.suppress(ValueError): self._keys.remove(key) - except ValueError: - pass def __len__(self): """ @@ -833,7 +826,7 @@ def dump(self, exclude_attributes: Optional[List[str]] = None): _keys.append(_ser) res["keys"] = _keys - for attr, default in self.params.items(): + for attr in self.params: if attr in exclude_attributes: continue val = getattr(self, attr) @@ -856,7 +849,7 @@ def load(self, spec): self._keys.extend(self.jwk_dicts_as_keys(_keys)) self.last_updated = time.time() - for attr, default in self.params.items(): + for attr in self.params: val = spec.get(attr) if val: setattr(self, attr, val) @@ -922,12 +915,18 @@ def dump_jwks(kbl, target, private=False, symmetric_too=False): keys = [] for _bundle in kbl: if symmetric_too: - keys.extend([k.serialize(private) for k in _bundle.keys() if not k.inactive_since]) + keys.extend( + [ + k.serialize(private) + for k in _bundle.keys() # noqa: SIM118 + if not k.inactive_since + ] + ) else: keys.extend( [ k.serialize(private) - for k in _bundle.keys() + for k in _bundle.keys() # noqa: SIM118 if k.kty != "oct" and not k.inactive_since ] ) @@ -935,15 +934,13 @@ def dump_jwks(kbl, target, private=False, symmetric_too=False): res = {"keys": keys} try: - _fp = open(target, "w") - except IOError: + with open(target, "w") as fp: + json.dump(res, fp) + except OSError: head, _ = os.path.split(target) os.makedirs(head) - _fp = open(target, "w") - - _txt = json.dumps(res) - _fp.write(_txt) - _fp.close() + with open(target, "w") as fp: + json.dump(res, fp) def _set_kid(spec, bundle, kid_template, kid): @@ -951,7 +948,7 @@ def _set_kid(spec, bundle, kid_template, kid): _keys = bundle.keys() _keys[0].kid = spec["kid"] else: - for k in bundle.keys(): + for k in bundle.keys(): # noqa: SIM118 if kid_template: k.kid = kid_template % kid kid += 1 @@ -1010,7 +1007,7 @@ def build_key_bundle(key_conf, kid_template=""): if "key" in spec and spec["key"]: if os.path.isfile(spec["key"]): _bundle = KeyBundle( - source="file://%s" % spec["key"], + source="file://{}".format(spec["key"]), fileformat="der", keytype=typ, keyusage=spec["use"], @@ -1021,7 +1018,7 @@ def build_key_bundle(key_conf, kid_template=""): if "key" in spec and spec["key"]: if os.path.isfile(spec["key"]): _bundle = KeyBundle( - source="file://%s" % spec["key"], + source="file://{}".format(spec["key"]), fileformat="der", keytype=typ, keyusage=spec["use"], @@ -1032,7 +1029,7 @@ def build_key_bundle(key_conf, kid_template=""): if "key" in spec and spec["key"]: if os.path.isfile(spec["key"]): _bundle = KeyBundle( - source="file://%s" % spec["key"], + source="file://{}".format(spec["key"]), fileformat="der", keytype=typ, keyusage=spec["use"], @@ -1325,13 +1322,13 @@ def key_gen(type, **kwargs): crv = kwargs.get("crv", DEFAULT_EC_CURVE) if crv not in NIST2SEC: logging.error("Unknown curve: %s", crv) - raise ValueError("Unknown curve: {}".format(crv)) + raise ValueError(f"Unknown curve: {crv}") _key = new_ec_key(crv=crv, **kargs) elif type.upper() == "OKP": crv = kwargs.get("crv", DEFAULT_OKP_CURVE) if crv not in OKP_CRV2PUBLIC: logging.error("Unknown curve: %s", crv) - raise ValueError("Unknown curve: {}".format(crv)) + raise ValueError(f"Unknown curve: {crv}") _key = new_okp_key(crv=crv, **kargs) elif type.lower() in ["sym", "oct"]: keysize = kwargs.get("bytes", 24) @@ -1339,7 +1336,7 @@ def key_gen(type, **kwargs): _key = SYMKey(key=randomkey, **kargs) else: logging.error("Unknown key type: %s", type) - raise ValueError("Unknown key type: %s".format(type)) + raise ValueError("Unknown key type: %s".format()) return _key @@ -1374,4 +1371,4 @@ def key_by_alg(alg: str): elif alg.startswith("HS"): return key_gen("sym") - raise ValueError("Don't know who to create a key to use with '{}'".format(alg)) + raise ValueError(f"Don't know who to create a key to use with '{alg}'") diff --git a/src/cryptojwt/key_issuer.py b/src/cryptojwt/key_issuer.py index 08e540ac..6312940f 100755 --- a/src/cryptojwt/key_issuer.py +++ b/src/cryptojwt/key_issuer.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -class KeyIssuer(object): +class KeyIssuer: """A key issuer instance contains a number of KeyBundles.""" params = { @@ -65,7 +65,7 @@ def __init__( self.spec2key = {} def __repr__(self) -> str: - return ''.format(self.name, self.key_summary()) + return f'' def __getitem__(self, item): return self.get_bundles()[item] @@ -155,10 +155,7 @@ def all_keys(self): return res def __contains__(self, item): - for kb in self._bundles: - if item in kb: - return True - return False + return any(item in kb for kb in self._bundles) def items(self): _res = {} @@ -208,7 +205,7 @@ def export_jwks(self, private=False, usage=None): keys.extend( [ k.serialize(private) - for k in kb.keys() + for k in kb.keys() # noqa: SIM118 if k.inactive_since == 0 and (usage is None or (hasattr(k, "use") and k.use == usage)) ] @@ -232,8 +229,8 @@ def import_jwks(self, jwks): """ try: _keys = jwks["keys"] - except KeyError: - raise ValueError("Not a proper JWKS") + except KeyError as exc: + raise ValueError("Not a proper JWKS") from exc else: self._bundles.append( self.keybundle_cls(_keys, httpc=self.httpc, httpc_params=self.httpc_params) @@ -285,10 +282,7 @@ def get(self, key_use, key_type="", kid=None, alg="", **kwargs): :return: A possibly empty list of keys """ - if key_use in ["dec", "enc"]: - use = "enc" - else: - use = "sig" + use = "enc" if key_use in ["dec", "enc"] else "sig" if not key_type: if alg: @@ -327,7 +321,7 @@ def get(self, key_use, key_type="", kid=None, alg="", **kwargs): # if elliptic curve, have to check if I have a key of the right curve if key_type and key_type.upper() == "EC": if alg: - name = "P-{}".format(alg[2:]) # the type + name = f"P-{alg[2:]}" # the type _lst = [] for key in lst: if name != key.crv: @@ -372,7 +366,7 @@ def dump(self, exclude_attributes: Optional[List[str]] = None) -> dict: exclude_attributes = [] info = {} - for attr, default in self.params.items(): + for attr in self.params: if attr in exclude_attributes: continue val = getattr(self, attr) @@ -394,7 +388,7 @@ def load(self, info): :param items: A dictionary with the information to load :return: """ - for attr, default in self.params.items(): + for attr in self.params: val = info.get(attr) if val: if attr == "keybundle_cls": @@ -441,16 +435,15 @@ def key_summary(self) -> str: """ key_list = [] for kb in self._bundles: - for key in kb.keys(): + for key in kb.keys(): # noqa: SIM118 if key.inactive_since: - key_list.append("*{}:{}:{}".format(key.kty, key.use, key.kid)) + key_list.append(f"*{key.kty}:{key.use}:{key.kid}") else: - key_list.append("{}:{}:{}".format(key.kty, key.use, key.kid)) + key_list.append(f"{key.kty}:{key.use}:{key.kid}") return ", ".join(key_list) def __iter__(self): - for bundle in self._bundles: - yield bundle + yield from self._bundles def __eq__(self, other): if not isinstance(other, self.__class__): @@ -463,11 +456,7 @@ def __eq__(self, other): if k not in other: return False - for k in other.all_keys(): - if k not in self: - return False - - return True + return all(k in self for k in other.all_keys()) def rotate_keys(self, key_conf, kid_template=""): """ @@ -586,7 +575,8 @@ def init_key_issuer(public_path="", private_path="", key_defs="", read_only=True if private_path: if os.path.isfile(private_path): - _jwks = open(private_path, "r").read() + with open(private_path) as fp: + _jwks = fp.read() _issuer = KeyIssuer() _issuer.import_jwks(json.loads(_jwks)) if key_defs: @@ -599,9 +589,8 @@ def init_key_issuer(public_path="", private_path="", key_defs="", read_only=True else: _issuer.set([_kb]) jwks = _issuer.export_jwks(private=True) - fp = open(private_path, "w") - fp.write(json.dumps(jwks)) - fp.close() + with open(private_path, "w") as fp: + json.dump(jwks, fp) else: _issuer = build_keyissuer(key_defs) if not read_only: @@ -609,21 +598,20 @@ def init_key_issuer(public_path="", private_path="", key_defs="", read_only=True head, tail = os.path.split(private_path) if head and not os.path.isdir(head): os.makedirs(head) - fp = open(private_path, "w") - fp.write(json.dumps(jwks)) - fp.close() + with open(private_path, "w") as fp: + json.dump(jwks, fp) if public_path and not read_only: jwks = _issuer.export_jwks() # public part head, tail = os.path.split(public_path) if head and not os.path.isdir(head): os.makedirs(head) - fp = open(public_path, "w") - fp.write(json.dumps(jwks)) - fp.close() + with open(public_path, "w") as fp: + json.dump(jwks, fp) elif public_path: if os.path.isfile(public_path): - _jwks = open(public_path, "r").read() + with open(public_path) as fp: + _jwks = fp.read() _issuer = KeyIssuer() _issuer.import_jwks(json.loads(_jwks)) if key_defs: @@ -636,9 +624,8 @@ def init_key_issuer(public_path="", private_path="", key_defs="", read_only=True update_key_bundle(_kb, _diff) _issuer.set([_kb]) jwks = _issuer.export_jwks() - fp = open(public_path, "w") - fp.write(json.dumps(jwks)) - fp.close() + with open(public_path, "w") as fp: + json.dump(jwks, fp) else: _issuer = build_keyissuer(key_defs) if not read_only: @@ -646,9 +633,8 @@ def init_key_issuer(public_path="", private_path="", key_defs="", read_only=True head, tail = os.path.split(public_path) if head and not os.path.isdir(head): os.makedirs(head) - fp = open(public_path, "w") - fp.write(json.dumps(_jwks)) - fp.close() + with open(public_path, "w") as fp: + json.dump(_jwks, fp) else: _issuer = build_keyissuer(key_defs) diff --git a/src/cryptojwt/key_jar.py b/src/cryptojwt/key_jar.py index 2bec768e..813efc73 100755 --- a/src/cryptojwt/key_jar.py +++ b/src/cryptojwt/key_jar.py @@ -1,3 +1,4 @@ +import contextlib import json import logging from typing import List @@ -24,7 +25,7 @@ logger = logging.getLogger(__name__) -class KeyJar(object): +class KeyJar: """A keyjar contains a number of KeyBundles sorted by owner/issuer""" def __init__( @@ -102,7 +103,7 @@ def items(self): def __repr__(self): issuers = self._issuer_ids() - return "".format(issuers) + return f"" @deprecated_alias(issuer="issuer_id", owner="issuer_id") def return_issuer(self, issuer_id): @@ -255,10 +256,7 @@ def get_issuer_keys(self, issuer_id): @deprecated_alias(issuer="issuer_id", owner="issuer_id") def __contains__(self, issuer_id): _iss = self._get_issuer(issuer_id) - if _iss is None: - return False - else: - return True + return _iss is not None @deprecated_alias(issuer="issuer_id", owner="issuer_id") def __getitem__(self, issuer_id=""): @@ -303,11 +301,11 @@ def match_owner(self, url): :param url: A URL :return: An issue entity ID that exists in the Key jar """ - _iss = [i for i in self._issuers.keys() if i.startswith(url)] + _iss = [i for i in self._issuers.keys() if i.startswith(url)] # noqa: SIM118 if _iss: return _iss[0] - raise KeyError("No keys for '{}' in this keyjar".format(url)) + raise KeyError(f"No keys for '{url}' in this keyjar") def __str__(self): _res = {} @@ -385,7 +383,7 @@ def export_jwks(self, private=False, issuer_id="", usage=None): keys.extend( [ k.serialize(private) - for k in kb.keys() + for k in kb.keys() # noqa: SIM118 if k.inactive_since == 0 and (usage is None or (hasattr(k, "use") and k.use == usage)) ] @@ -413,8 +411,8 @@ def import_jwks(self, jwks, issuer_id): """ try: _keys = jwks["keys"] - except KeyError: - raise ValueError("Not a proper JWKS") + except KeyError as exc: + raise ValueError("Not a proper JWKS") from exc if _keys: _issuer = self.return_issuer(issuer_id=issuer_id) @@ -446,11 +444,7 @@ def __eq__(self, other): return False # Keys per issuer must be the same - for iss in self.owners(): - if self[iss] != other[iss]: - return False - - return True + return all(self[iss] == other[iss] for iss in self.owners()) def __delitem__(self, key): del self._issuers[key] @@ -484,10 +478,10 @@ def _add_key( ): _issuer = self._get_issuer(issuer_id) if _issuer is None: - logger.error('Issuer "{}" not in keyjar'.format(issuer_id)) + logger.error(f'Issuer "{issuer_id}" not in keyjar') raise IssuerNotFound(issuer_id) - logger.debug("Key summary for {}: {}".format(issuer_id, _issuer.key_summary())) + logger.debug(f"Key summary for {issuer_id}: {_issuer.key_summary()}") if kid: for _key in _issuer.get(use, kid=kid, key_type=key_type): @@ -669,16 +663,10 @@ def dump( if exclude_attributes: for attr in exclude_attributes: - try: + with contextlib.suppress(KeyError): del info[attr] - except KeyError: - pass - if exclude_attributes is None: - info["issuers"] = self._dump_issuers( - exclude_issuers=exclude_issuers, exclude_attributes=exclude_attributes - ) - elif "issuers" not in exclude_attributes: + if exclude_attributes is None or "issuers" not in exclude_attributes: info["issuers"] = self._dump_issuers( exclude_issuers=exclude_issuers, exclude_attributes=exclude_attributes ) @@ -706,13 +694,13 @@ def load( :param info: A dictionary with the information :return: """ - self.ca_certs = info.get("ca_certs", None) - self.httpc_params = info.get("httpc_params", None) + self.ca_certs = info.get("ca_certs") + self.httpc_params = info.get("httpc_params") self.keybundle_cls = importer(info.get("keybundle_cls", KeyBundle)) self.remove_after = info.get("remove_after", 3600) self.spec2key = info.get("spec2key", {}) - _issuers = info.get("issuers", None) + _issuers = info.get("issuers") if _issuers is None: self._issuers = {} else: diff --git a/src/cryptojwt/simple_jwt.py b/src/cryptojwt/simple_jwt.py index fbdaf2b8..f2c45154 100644 --- a/src/cryptojwt/simple_jwt.py +++ b/src/cryptojwt/simple_jwt.py @@ -1,3 +1,4 @@ +import contextlib import json import logging @@ -13,7 +14,7 @@ logger = logging.getLogger(__name__) -class SimpleJWT(object): +class SimpleJWT: """ Basic JSON Web Token class that doesn't make any assumptions as to what can or should be in the payload @@ -36,10 +37,8 @@ def unpack(self, token, **kwargs): against. """ if isinstance(token, str): - try: + with contextlib.suppress(UnicodeDecodeError): token = token.encode("utf-8") - except UnicodeDecodeError: - pass part = split_token(token) self.b64part = part @@ -55,9 +54,7 @@ def unpack(self, token, **kwargs): raise else: if not _ok: - raise HeaderError( - 'Expected "{}" to be "{}", was "{}"'.format(key, val, self.headers[key]) - ) + raise HeaderError(f'Expected "{key}" to be "{val}", was "{self.headers[key]}"') return self @@ -70,12 +67,9 @@ def pack(self, parts=None, headers=None): :return: """ if not headers: - if self.headers: - headers = self.headers - else: - headers = {"alg": "none"} + headers = self.headers if self.headers else {"alg": "none"} - logging.debug("(pack) JWT header: {}".format(headers)) + logging.debug(f"(pack) JWT header: {headers}") if not parts: return ".".join([a.decode() for a in self.b64part]) @@ -100,10 +94,8 @@ def payload(self): if "cty" in self.headers and self.headers["cty"].lower() != "jwt": pass else: - try: + with contextlib.suppress(ValueError): _msg = json.loads(_msg) - except ValueError: - pass return _msg @@ -119,15 +111,9 @@ def verify_header(self, key, val): """ if isinstance(val, list): - if self.headers[key] in val: - return True - else: - return False + return self.headers[key] in val else: - if self.headers[key] == val: - return True - else: - return False + return self.headers[key] == val def verify_headers(self, check_presence=True, **kwargs): """ diff --git a/src/cryptojwt/tools/jwtpeek.py b/src/cryptojwt/tools/jwtpeek.py index a0f420e7..2d568717 100755 --- a/src/cryptojwt/tools/jwtpeek.py +++ b/src/cryptojwt/tools/jwtpeek.py @@ -18,7 +18,6 @@ from cryptojwt.jws import jws from cryptojwt.key_bundle import KeyBundle from cryptojwt.key_issuer import KeyIssuer -from cryptojwt.key_jar import KeyJar __author__ = "roland" @@ -53,7 +52,7 @@ def process(jwt, keys, quiet): if _jw: if not quiet: print("Encrypted JSON Web Token") - print("Headers: {}".format(_jw.jwt.headers)) + print(f"Headers: {_jw.jwt.headers}") if keys: res = _jw.decrypt(keys=keys) json_object = json.loads(res) @@ -71,17 +70,15 @@ def process(jwt, keys, quiet): print(highlight(json_str, JsonLexer(), TerminalFormatter())) else: print("Signed JSON Web Token") - print("Headers: {}".format(_jw.jwt.headers)) + print(f"Headers: {_jw.jwt.headers}") if keys: res = _jw.verify_compact(keys=keys) - print("Verified message: {}".format(res)) + print(f"Verified message: {res}") else: json_object = json.loads(_jw.jwt.part[1].decode("utf-8")) json_str = json.dumps(json_object, indent=2) print( - "Unverified message: {}".format( - highlight(json_str, JsonLexer(), TerminalFormatter()) - ) + f"Unverified message: {highlight(json_str, JsonLexer(), TerminalFormatter())}" ) @@ -103,10 +100,7 @@ def main(): args = parser.parse_args() - if args.kid: - _kid = args.kid - else: - _kid = "" + _kid = args.kid if args.kid else "" keys = [] if args.rsa_file: @@ -115,25 +109,26 @@ def main(): keys.append(SYMKey(key=args.hmac_key, kid=_kid)) if args.jwk: - _key = key_from_jwk_dict(open(args.jwk).read()) + with open(args.jwk) as fp: + _key = key_from_jwk_dict(fp.read()) keys.append(_key) if args.jwks: _iss = KeyIssuer() - _iss.import_jwks(open(args.jwks).read()) + with open(args.jwks) as fp: + _iss.import_jwks(fp.read()) keys.extend(_iss.all_keys()) if args.jwks_url: _kb = KeyBundle(source=args.jwks_url) keys.extend(_kb.get()) - if not args.msg: # If nothing specified assume stdin - message = sys.stdin.read() - elif args.msg == "-": + if not args.msg or args.msg == "-": # If nothing specified assume stdin message = sys.stdin.read() else: if os.path.isfile(args.msg): - message = open(args.msg).read().strip("\n") + with open(args.msg) as fp: + message = fp.read().strip("\n") else: message = args.msg diff --git a/src/cryptojwt/tools/keyconv.py b/src/cryptojwt/tools/keyconv.py index 969a2b01..5a9234d6 100644 --- a/src/cryptojwt/tools/keyconv.py +++ b/src/cryptojwt/tools/keyconv.py @@ -25,7 +25,7 @@ def jwk_from_file(filename: str, private: bool = True) -> JWK: """Read JWK from file""" - with open(filename, mode="rt") as input_file: + with open(filename) as input_file: jwk_dict = json.loads(input_file.read()) return key_from_jwk_dict(jwk_dict, private=private) @@ -93,7 +93,7 @@ def pem2jwk( passphrase: Optional[str] = None, ) -> JWK: """Read PEM from filename and return JWK""" - with open(filename, "rt") as file: + with open(filename) as file: content = file.readlines() header = content[0] @@ -178,7 +178,7 @@ def output_jwk(jwk: JWK, private: bool = False, filename: Optional[str] = None) """Output JWK to file""" serialized = jwk.serialize(private=private) if filename is not None: - with open(filename, mode="wt") as file: + with open(filename, mode="w") as file: file.write(json.dumps(serialized)) else: print(json.dumps(serialized, indent=4)) diff --git a/src/cryptojwt/tools/keygen.py b/src/cryptojwt/tools/keygen.py index 3a028613..718bb628 100644 --- a/src/cryptojwt/tools/keygen.py +++ b/src/cryptojwt/tools/keygen.py @@ -11,7 +11,6 @@ from cryptojwt.jwk.okp import OKP_CRV2PUBLIC from cryptojwt.jwk.okp import new_okp_key from cryptojwt.jwk.rsa import new_rsa_key -from cryptojwt.utils import b64e DEFAULT_SYM_KEYSIZE = 32 DEFAULT_RSA_KEYSIZE = 2048 @@ -38,7 +37,7 @@ def main(): dest="rsa_exp", type=int, metavar="exponent", - help="RSA public key exponent (default {})".format(DEFAULT_RSA_EXP), + help=f"RSA public key exponent (default {DEFAULT_RSA_EXP})", default=DEFAULT_RSA_EXP, ) parser.add_argument("--kid", dest="kid", metavar="id", help="Key ID") @@ -50,12 +49,12 @@ def main(): jwk = new_rsa_key(public_exponent=args.rsa_exp, key_size=args.keysize, kid=args.kid) elif args.kty.upper() == "EC": if args.crv not in NIST2SEC: - print("Unknown curve: {0}".format(args.crv), file=sys.stderr) + print(f"Unknown curve: {args.crv}", file=sys.stderr) exit(1) jwk = new_ec_key(crv=args.crv, kid=args.kid) elif args.kty.upper() == "OKP": if args.crv not in OKP_CRV2PUBLIC: - print("Unknown curve: {0}".format(args.crv), file=sys.stderr) + print(f"Unknown curve: {args.crv}", file=sys.stderr) exit(1) jwk = new_okp_key(crv=args.crv, kid=args.kid) elif args.kty.upper() == "SYM" or args.kty.upper() == "OCT": @@ -63,7 +62,7 @@ def main(): args.keysize = DEFAULT_SYM_KEYSIZE jwk = new_sym_key(bytes=args.keysize, kid=args.kid) else: - print("Unknown key type: {}".format(args.kty), file=sys.stderr) + print(f"Unknown key type: {args.kty}", file=sys.stderr) exit(1) jwk_dict = jwk.serialize(private=True) diff --git a/src/cryptojwt/utils.py b/src/cryptojwt/utils.py index 5c13d91a..47785e02 100644 --- a/src/cryptojwt/utils.py +++ b/src/cryptojwt/utils.py @@ -1,4 +1,5 @@ import base64 +import contextlib import functools import importlib import json @@ -19,11 +20,11 @@ def intarr2bin(arr): - return unhexlify("".join(["%02x" % byte for byte in arr])) + return unhexlify("".join([f"{byte:02x}" for byte in arr])) def intarr2long(arr): - return int("".join(["%02x" % byte for byte in arr]), 16) + return int("".join([f"{byte:02x}" for byte in arr]), 16) def intarr2str(arr): @@ -44,7 +45,7 @@ def long_to_base64(n, mlen=0): _len = mlen - len(bys) if _len: bys = [0] * _len + bys - data = struct.pack("%sB" % len(bys), *bys) + data = struct.pack(f"{len(bys)}B", *bys) if not len(data): data = b"\x00" s = base64.urlsafe_b64encode(data).rstrip(b"=") @@ -57,7 +58,7 @@ def base64_to_long(data): # urlsafe_b64decode will happily convert b64encoded data _d = base64.urlsafe_b64decode(as_bytes(data) + b"==") - return intarr2long(struct.unpack("%sB" % len(_d), _d)) + return intarr2long(struct.unpack(f"{len(_d)}B", _d)) def base64url_to_long(data): @@ -74,7 +75,7 @@ def base64url_to_long(data): # that is no '+' and '/' characters and not trailing "="s. if [e for e in [b"+", b"/", b"="] if e in _data]: raise ValueError("Not base64url encoded") - return intarr2long(struct.unpack("%sB" % len(_d), _d)) + return intarr2long(struct.unpack(f"{len(_d)}B", _d)) # ============================================================================= @@ -140,10 +141,8 @@ def as_bytes(s): :param s: Unicode / bytes string :return: bytes string """ - try: + with contextlib.suppress(AttributeError, UnicodeDecodeError): s = s.encode() - except (AttributeError, UnicodeDecodeError): - pass return s @@ -154,10 +153,8 @@ def as_unicode(b): :param b: byte string :return: unicode string """ - try: + with contextlib.suppress(AttributeError, UnicodeDecodeError): b = b.decode() - except (AttributeError, UnicodeDecodeError): - pass return b @@ -172,7 +169,7 @@ def bytes2str_conv(item): elif isinstance(item, dict): return dict([(k, bytes2str_conv(v)) for k, v in item.items()]) - raise ValueError("Can't convert {}.".format(repr(item))) + raise ValueError(f"Can't convert {repr(item)}.") def b64encode_item(item): @@ -200,10 +197,7 @@ def deser(val): :param val: The string representation of the long integer. :return: The long integer. """ - if isinstance(val, str): - _val = val.encode("utf-8") - else: - _val = val + _val = val.encode("utf-8") if isinstance(val, str) else val return base64_to_long(_val) @@ -256,8 +250,8 @@ def rename_kwargs(func_name, kwargs, aliases): for alias, new in aliases.items(): if alias in kwargs: if new in kwargs: - raise TypeError("{} received both {} and {}".format(func_name, alias, new)) - warnings.warn("{} is deprecated; use {}".format(alias, new), DeprecationWarning) + raise TypeError(f"{func_name} received both {alias} and {new}") + warnings.warn(f"{alias} is deprecated; use {new}", DeprecationWarning, stacklevel=1) kwargs[new] = kwargs.pop(alias) diff --git a/tests/test_02_jwk.py b/tests/test_02_jwk.py index c01e85ad..3eeddd3e 100755 --- a/tests/test_02_jwk.py +++ b/tests/test_02_jwk.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -from __future__ import print_function import base64 import json @@ -9,11 +8,8 @@ import pytest from cryptography.hazmat.primitives.asymmetric import ec -from cryptography.hazmat.primitives.asymmetric import ed448 from cryptography.hazmat.primitives.asymmetric import ed25519 from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric import x448 -from cryptography.hazmat.primitives.asymmetric import x25519 from cryptojwt.exception import DeSerializationNotPossible from cryptojwt.exception import UnsupportedAlgorithm @@ -72,10 +68,10 @@ def _eq(l1, l2): def test_urlsafe_base64decode(): - l = base64_to_long(N) + length = base64_to_long(N) # convert it to base64 - bys = long2intarr(l) - data = struct.pack("%sB" % len(bys), *bys) + bys = long2intarr(length) + data = struct.pack(f"{len(bys)}B", *bys) if not len(data): data = "\x00" s0 = base64.b64encode(data) @@ -85,8 +81,8 @@ def test_urlsafe_base64decode(): base64url_to_long(s0) # Not else, should not raise exception - l = base64_to_long(s0) - assert l + length = base64_to_long(s0) + assert length def test_import_rsa_key_from_cert_file(): @@ -250,7 +246,8 @@ def test_get_key(): def test_private_rsa_key_from_jwk(): keys = [] - kspec = json.loads(open(full_path("jwk_private_key.json")).read()) + with open(full_path("jwk_private_key.json")) as fp: + kspec = json.loads(fp.read()) keys.append(key_from_jwk_dict(kspec)) key = keys[0] @@ -276,7 +273,8 @@ def test_private_rsa_key_from_jwk(): def test_public_key_from_jwk(): keys = [] - kspec = json.loads(open(full_path("jwk_private_key.json")).read()) + with open(full_path("jwk_private_key.json")) as fp: + kspec = json.loads(fp.read()) keys.append(key_from_jwk_dict(kspec, private=False)) key = keys[0] @@ -292,7 +290,8 @@ def test_public_key_from_jwk(): def test_ec_private_key_from_jwk(): keys = [] - kspec = json.loads(open(full_path("jwk_private_ec_key.json")).read()) + with open(full_path("jwk_private_ec_key.json")) as fp: + kspec = json.loads(fp.read()) keys.append(key_from_jwk_dict(kspec)) key = keys[0] @@ -310,7 +309,8 @@ def test_ec_private_key_from_jwk(): def test_ec_public_key_from_jwk(): keys = [] - kspec = json.loads(open(full_path("jwk_private_ec_key.json")).read()) + with open(full_path("jwk_private_ec_key.json")) as fp: + kspec = json.loads(fp.read()) keys.append(key_from_jwk_dict(kspec, private=False)) key = keys[0] @@ -566,7 +566,7 @@ def test_jwk_conversion(): def test_str(): _j = RSAKey(alg="RS512", use="sig", n=N, e=E) - s = "{}".format(_j) + s = f"{_j}" assert s.startswith("{") and s.endswith("}") sp = s.replace("'", '"') _d = json.loads(sp) diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index e4841924..cc1704c6 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -1,4 +1,5 @@ # pylint: disable=missing-docstring,no-self-use +import contextlib import json import os import shutil @@ -15,7 +16,6 @@ from cryptojwt.jwk.ec import new_ec_key from cryptojwt.jwk.hmac import SYMKey from cryptojwt.jwk.okp import OKPKey -from cryptojwt.jwk.okp import new_okp_key from cryptojwt.jwk.rsa import RSAKey from cryptojwt.jwk.rsa import import_rsa_key_from_cert_file from cryptojwt.jwk.rsa import new_rsa_key @@ -269,7 +269,7 @@ def test_get_all(): def test_keybundle_from_local_der(): - kb = keybundle_from_local_file("{}".format(RSA0), "der", ["enc"]) + kb = keybundle_from_local_file(f"{RSA0}", "der", ["enc"]) assert len(kb) == 1 keys = kb.get("rsa") assert len(keys) == 1 @@ -279,7 +279,7 @@ def test_keybundle_from_local_der(): def test_ec_keybundle_from_local_der(): - kb = keybundle_from_local_file("{}".format(EC0), "der", ["enc"], keytype="EC") + kb = keybundle_from_local_file(f"{EC0}", "der", ["enc"], keytype="EC") assert len(kb) == 1 keys = kb.get("ec") assert len(keys) == 1 @@ -289,7 +289,7 @@ def test_ec_keybundle_from_local_der(): def test_keybundle_from_local_der_update(): - kb = keybundle_from_local_file("file://{}".format(RSA0), "der", ["enc"]) + kb = keybundle_from_local_file(f"file://{RSA0}", "der", ["enc"]) assert len(kb) == 1 keys = kb.get("rsa") assert len(keys) == 1 @@ -409,7 +409,7 @@ def test_mark_as_inactive(): desc = {"kty": "oct", "key": "highestsupersecret", "use": "sig"} kb = KeyBundle([desc]) assert len(kb.keys()) == 1 - for k in kb.keys(): + for k in kb.keys(): # noqa kb.mark_as_inactive(k.kid) desc = {"kty": "oct", "key": "highestsupersecret", "use": "enc"} kb.add_jwk_dicts([desc]) @@ -421,7 +421,7 @@ def test_copy(): desc = {"kty": "oct", "key": "highestsupersecret", "use": "sig"} kb = KeyBundle([desc]) assert len(kb.keys()) == 1 - for k in kb.keys(): + for k in kb.keys(): # noqa kb.mark_as_inactive(k.kid) desc = {"kty": "oct", "key": "highestsupersecret", "use": "enc"} kb.add_jwk_dicts([desc]) @@ -433,14 +433,14 @@ def test_copy(): def test_local_jwk(): _path = full_path("jwk_private_key.json") - kb = KeyBundle(source="file://{}".format(_path)) + kb = KeyBundle(source=f"file://{_path}") assert kb def test_local_jwk_update(): cache_time = 0.1 _path = full_path("jwk_private_key.json") - kb = KeyBundle(source="file://{}".format(_path), cache_time=cache_time) + kb = KeyBundle(source=f"file://{_path}", cache_time=cache_time) assert kb _ = kb.keys() last1 = kb.last_local @@ -456,7 +456,7 @@ def test_local_jwk_update(): def test_local_jwk_copy(): _path = full_path("jwk_private_key.json") - kb = KeyBundle(source="file://{}".format(_path)) + kb = KeyBundle(source=f"file://{_path}") kb2 = kb.copy() assert kb2.source == kb.source @@ -483,7 +483,7 @@ def test_httpc_params_1(): httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds kb = KeyBundle(source=source, httpc=requests.request, httpc_params=httpc_params) updated, _ = kb._do_remote() - assert updated == True + assert updated is True @pytest.mark.network @@ -506,7 +506,7 @@ def test_update_2(): with open(fname, "w") as fp: fp.write(json.dumps(_jwks)) - kb = KeyBundle(source="file://{}".format(fname), fileformat="jwks") + kb = KeyBundle(source=f"file://{fname}", fileformat="jwks") assert len(kb) == 1 # Added one more key @@ -528,7 +528,7 @@ def test_update_mark_inactive(): with open(fname, "w") as fp: fp.write(json.dumps(_jwks)) - kb = KeyBundle(source="file://{}".format(fname), fileformat="jwks") + kb = KeyBundle(source=f"file://{fname}", fileformat="jwks") assert len(kb) == 1 # new set of keys @@ -909,7 +909,7 @@ def test_export_inactive(): desc = {"kty": "oct", "key": "highestsupersecret", "use": "sig"} kb = KeyBundle([desc]) assert len(kb.keys()) == 1 - for k in kb.keys(): + for k in kb.keys(): # noqa kb.mark_as_inactive(k.kid) desc = {"kty": "oct", "key": "highestsupersecret", "use": "enc"} kb.add_jwk_dicts([desc]) @@ -977,7 +977,7 @@ def test_remote_not_modified(): with responses.RequestsMock() as rsps: rsps.add(method="GET", url=source, json=JWKS_DICT, status=200, headers=headers) updated, _ = kb._do_remote() - assert updated == True + assert updated is True assert kb.last_remote == headers.get("Last-Modified") timeout1 = kb.time_out @@ -1020,20 +1020,18 @@ def test_ignore_errors_period(): ignore_errors_period=ignore_errors_period, ) res, _ = kb._do_remote() - assert res == True + assert res is True assert kb.ignore_errors_until is None # refetch, but fail by using a bad source kb.source = source_bad - try: + with contextlib.suppress(UpdateFailed): res, _ = kb._do_remote() - except UpdateFailed: - pass # retry should fail silently as we're in holddown res, _ = kb._do_remote() assert kb.ignore_errors_until is not None - assert res == False + assert res is False # wait until holddown time.sleep(ignore_errors_period + 1) @@ -1041,7 +1039,7 @@ def test_ignore_errors_period(): # try again kb.source = source_good res, _ = kb._do_remote() - assert res == True + assert res is True def test_ignore_invalid_keys(): diff --git a/tests/test_04_key_issuer.py b/tests/test_04_key_issuer.py index 6ff041f1..a9d05bf6 100755 --- a/tests/test_04_key_issuer.py +++ b/tests/test_04_key_issuer.py @@ -208,7 +208,7 @@ def test_build_EC_keyissuer_from_file(tmpdir): assert len(key_issuer) == 2 -class TestKeyJar(object): +class TestKeyJar: def test_keyissuer_add(self): issuer = KeyIssuer() kb = keybundle_from_local_file(RSAKEY, "der", ["ver", "sig"]) @@ -287,7 +287,7 @@ def test_get_enc_not_mine(self): assert issuer.get("enc", "oct") def test_dump_issuer_keys(self): - kb = keybundle_from_local_file("file://%s/jwk.json" % BASE_PATH, "jwks", ["sig"]) + kb = keybundle_from_local_file(f"file://{BASE_PATH}/jwk.json", "jwks", ["sig"]) assert len(kb) == 1 issuer = KeyIssuer() issuer.add_kb(kb) @@ -480,8 +480,8 @@ def test_keyissuer_eq(): assert kj1 == kj2 -PUBLIC_FILE = "{}/public_jwks.json".format(BASEDIR) -PRIVATE_FILE = "{}/private_jwks.json".format(BASEDIR) +PUBLIC_FILE = f"{BASEDIR}/public_jwks.json" +PRIVATE_FILE = f"{BASEDIR}/private_jwks.json" KEYSPEC = [ {"type": "RSA", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -598,17 +598,17 @@ def test_init_key_issuer_update(): OIDC_KEYS = { - "private_path": "{}/priv/jwks.json".format(BASEDIR), + "private_path": f"{BASEDIR}/priv/jwks.json", "key_defs": KEYSPEC, - "public_path": "{}/public/jwks.json".format(BASEDIR), + "public_path": f"{BASEDIR}/public/jwks.json", } def test_init_key_issuer_create_directories(): # make sure the directories are gone for _dir in ["priv", "public"]: - if os.path.isdir("{}/{}".format(BASEDIR, _dir)): - shutil.rmtree("{}/{}".format(BASEDIR, _dir)) + if os.path.isdir(f"{BASEDIR}/{_dir}"): + shutil.rmtree(f"{BASEDIR}/{_dir}") _keyissuer = init_key_issuer(**OIDC_KEYS) assert len(_keyissuer.get("sig", "RSA")) == 1 @@ -617,7 +617,7 @@ def test_init_key_issuer_create_directories(): OIDC_PUB_KEYS = { "key_defs": KEYSPEC, - "public_path": "{}/public/jwks.json".format(BASEDIR), + "public_path": f"{BASEDIR}/public/jwks.json", "read_only": False, } @@ -625,8 +625,8 @@ def test_init_key_issuer_create_directories(): def test_init_key_issuer_public_key_only(): # make sure the directories are gone for _dir in ["public"]: - if os.path.isdir("{}/{}".format(BASEDIR, _dir)): - shutil.rmtree("{}/{}".format(BASEDIR, _dir)) + if os.path.isdir(f"{BASEDIR}/{_dir}"): + shutil.rmtree(f"{BASEDIR}/{_dir}") _keyissuer = init_key_issuer(**OIDC_PUB_KEYS) assert len(_keyissuer.get("sig", "RSA")) == 1 @@ -639,7 +639,7 @@ def test_init_key_issuer_public_key_only(): OIDC_PUB_KEYS2 = { "key_defs": KEYSPEC_3, - "public_path": "{}/public/jwks.json".format(BASEDIR), + "public_path": f"{BASEDIR}/public/jwks.json", "read_only": False, } @@ -647,8 +647,8 @@ def test_init_key_issuer_public_key_only(): def test_init_key_issuer_public_key_only_with_diff(): # make sure the directories are gone for _dir in ["public"]: - if os.path.isdir("{}/{}".format(BASEDIR, _dir)): - shutil.rmtree("{}/{}".format(BASEDIR, _dir)) + if os.path.isdir(f"{BASEDIR}/{_dir}"): + shutil.rmtree(f"{BASEDIR}/{_dir}") _keyissuer = init_key_issuer(**OIDC_PUB_KEYS) assert len(_keyissuer.get("sig", "RSA")) == 1 @@ -701,7 +701,7 @@ def test_localhost_url(): kb = issuer.find(url) assert len(kb) == 1 assert "verify" in kb[0].httpc_params - assert kb[0].httpc_params["verify"] == False + assert kb[0].httpc_params["verify"] is False def test_add_url(): diff --git a/tests/test_04_key_jar.py b/tests/test_04_key_jar.py index c2d9d2dc..71fadd25 100755 --- a/tests/test_04_key_jar.py +++ b/tests/test_04_key_jar.py @@ -215,7 +215,7 @@ def test_build_EC_keyjar_from_file(tmpdir): assert len(key_jar[""]) == 2 -class TestKeyJar(object): +class TestKeyJar: def test_keyjar_add(self): kj = KeyJar() kb = keybundle_from_local_file(RSAKEY, "der", ["ver", "sig"]) @@ -370,7 +370,7 @@ def test_get_enc_not_mine(self): assert ks.get("enc", "oct", "http://www.example.org/") def test_dump_issuer_keys(self): - kb = keybundle_from_local_file("file://%s/jwk.json" % BASE_PATH, "jwks", ["sig"]) + kb = keybundle_from_local_file(f"file://{BASE_PATH}/jwk.json", "jwks", ["sig"]) assert len(kb) == 1 kj = KeyJar() kj.add_kb("", kb) @@ -405,14 +405,11 @@ def test_no_use(self): def test_provider(self): kj = KeyJar() _url = "https://connect-op.herokuapp.com/jwks.json" - kj.load_keys( - "https://connect-op.heroku.com", - jwks_uri=_url, - ) + kj.load_keys("https://connect-op.heroku.com", jwks_uri=_url) iss_keys = kj.get_issuer_keys("https://connect-op.heroku.com") if not iss_keys: - _msg = "{} is not available at this moment!".format(_url) - warnings.warn(_msg) + _msg = f"{_url} is not available at this moment!" + warnings.warn(_msg, stacklevel=1) else: assert iss_keys[0].keys() @@ -628,7 +625,7 @@ def test_keys_by_alg_and_usage(): assert len(k) == 2 -class TestVerifyJWTKeys(object): +class TestVerifyJWTKeys: @pytest.fixture(autouse=True) def setup(self): mkey = [ @@ -789,7 +786,7 @@ def test_str(): kj = KeyJar() kj.add_kb("Alice", KeyBundle(JWK0["keys"])) - desc = "{}".format(kj) + desc = f"{kj}" _cont = json.loads(desc) assert set(_cont.keys()) == {"Alice"} @@ -803,13 +800,13 @@ def test_load_keys(): def test_find(): _path = full_path("jwk_private_key.json") - kb = KeyBundle(source="file://{}".format(_path)) + kb = KeyBundle(source=f"file://{_path}") kj = KeyJar() kj.add_kb("Alice", kb) - assert kj.find("{}".format(_path), "Alice") + assert kj.find(f"{_path}", "Alice") assert kj.find("https://example.com", "Alice") == [] - assert kj.find("{}".format(_path), "Bob") is None + assert kj.find(f"{_path}", "Bob") is None def test_get_decrypt_keys(): @@ -839,7 +836,7 @@ def test_get_decrypt_keys(): def test_update_keyjar(): _path = full_path("jwk_private_key.json") - kb = KeyBundle(source="file://{}".format(_path)) + kb = KeyBundle(source=f"file://{_path}") kj = KeyJar() kj.add_kb("Alice", kb) @@ -856,8 +853,8 @@ def test_key_summary(): assert out == "RSA::abc" -PUBLIC_FILE = "{}/public_jwks.json".format(BASEDIR) -PRIVATE_FILE = "{}/private_jwks.json".format(BASEDIR) +PUBLIC_FILE = f"{BASEDIR}/public_jwks.json" +PRIVATE_FILE = f"{BASEDIR}/private_jwks.json" KEYSPEC = [ {"type": "RSA", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]}, @@ -977,17 +974,17 @@ def test_init_key_jar_update(): OIDC_KEYS = { - "private_path": "{}/priv/jwks.json".format(BASEDIR), + "private_path": f"{BASEDIR}/priv/jwks.json", "key_defs": KEYSPEC, - "public_path": "{}/public/jwks.json".format(BASEDIR), + "public_path": f"{BASEDIR}/public/jwks.json", } def test_init_key_jar_create_directories(): # make sure the directories are gone for _dir in ["priv", "public"]: - if os.path.isdir("{}/{}".format(BASEDIR, _dir)): - shutil.rmtree("{}/{}".format(BASEDIR, _dir)) + if os.path.isdir(f"{BASEDIR}/{_dir}"): + shutil.rmtree(f"{BASEDIR}/{_dir}") _keyjar = init_key_jar(**OIDC_KEYS) assert len(_keyjar.get_signing_key("RSA")) == 1 diff --git a/tests/test_05_jwx.py b/tests/test_05_jwx.py index d394008f..7bfe03e0 100644 --- a/tests/test_05_jwx.py +++ b/tests/test_05_jwx.py @@ -50,7 +50,8 @@ def test_jws_set_jku(): def test_jwx_set_x5c(): - jwx = JWx(x5c=open(full_path("cert.pem")).read()) + with open(full_path("cert.pem")) as fp: + jwx = JWx(x5c=fp.read()) keys = jwx._get_keys() assert len(keys) assert isinstance(keys[0], RSAKey) diff --git a/tests/test_06_jws.py b/tests/test_06_jws.py index 4769f3c1..92751569 100644 --- a/tests/test_06_jws.py +++ b/tests/test_06_jws.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import json import os.path @@ -574,7 +572,7 @@ def test_signer_ps256_fail(): except BadSignature: pass else: - assert False + raise AssertionError def test_signer_ps384(): @@ -647,11 +645,11 @@ def test_signer_eddsa_fail(): _pubkey = OKPKey().load_key(okp2.public_key()) _rj = JWS(alg="Ed25519") try: - info = _rj.verify_compact(_jwt, [_pubkey]) + _ = _rj.verify_compact(_jwt, [_pubkey]) except BadSignature: pass else: - assert False + raise AssertionError def test_no_alg_and_alg_none_same(): @@ -906,7 +904,7 @@ def test_rs256_rm_signature(): except WrongNumberOfParts: pass else: - assert False + raise AssertionError def test_pick_alg_assume_alg_from_single_key(): diff --git a/tests/test_07_jwe.py b/tests/test_07_jwe.py index 7a87efce..9f4f5e8a 100644 --- a/tests/test_07_jwe.py +++ b/tests/test_07_jwe.py @@ -62,10 +62,7 @@ def str2intarr(string): return array.array("B", string).tolist() -if sys.version < "3": - to_intarr = str2intarr -else: - to_intarr = bytes2intarr +to_intarr = str2intarr if sys.version < "3" else bytes2intarr def test_jwe_09_a1(): @@ -671,9 +668,9 @@ def test_fernet_symkey(): def test_fernet_bad(): with pytest.raises(TypeError): - encrypter = FernetEncrypter(key="xyzzy") + _ = FernetEncrypter(key="xyzzy") with pytest.raises(ValueError): - encrypter = FernetEncrypter(key=os.urandom(16)) + _ = FernetEncrypter(key=os.urandom(16)) def test_fernet_bytes(): diff --git a/tests/test_09_jwt.py b/tests/test_09_jwt.py index 452ee8c2..1fa43db0 100755 --- a/tests/test_09_jwt.py +++ b/tests/test_09_jwt.py @@ -156,7 +156,7 @@ def test_jwt_pack_and_unpack_unknown_key(): kj.add_kb(ALICE, KeyBundle()) bob = JWT(key_jar=kj, iss=BOB, allowed_sign_algs=["RS256"]) with pytest.raises(NoSuitableSigningKeys): - info = bob.unpack(_jwt) + _ = bob.unpack(_jwt) def test_jwt_pack_and_unpack_with_lifetime(): @@ -240,7 +240,7 @@ def test_with_jti(): assert "jti" in info -class DummyMsg(object): +class DummyMsg: def __init__(self, **kwargs): for key, val in kwargs.items(): setattr(self, key, val) diff --git a/tests/test_30_tools.py b/tests/test_30_tools.py index 2eb46feb..0a6a0cdd 100644 --- a/tests/test_30_tools.py +++ b/tests/test_30_tools.py @@ -10,7 +10,7 @@ def jwk_from_file(filename: str, private: bool = True) -> JWK: """Read JWK from file""" - with open(filename, mode="rt") as input_file: + with open(filename) as input_file: jwk_dict = json.loads(input_file.read()) return key_from_jwk_dict(jwk_dict, private=private) diff --git a/tests/test_31_utils.py b/tests/test_31_utils.py index d403a40d..86e98a59 100644 --- a/tests/test_31_utils.py +++ b/tests/test_31_utils.py @@ -2,16 +2,16 @@ def test_check_content_type(): - assert check_content_type(content_type="application/json", mime_type="application/json") == True + assert check_content_type(content_type="application/json", mime_type="application/json") is True assert ( check_content_type( content_type="application/json; charset=utf-8", mime_type="application/json" ) - == True + is True ) assert ( check_content_type( content_type="application/html; charset=utf-8", mime_type="application/json" ) - == False + is False ) diff --git a/tests/test_40_serialize.py b/tests/test_40_serialize.py index f83b9f6a..6f3ec6c0 100644 --- a/tests/test_40_serialize.py +++ b/tests/test_40_serialize.py @@ -15,7 +15,7 @@ def full_path(local_file): def test_key_issuer(): - kb = keybundle_from_local_file("file://%s/jwk.json" % BASE_PATH, "jwks", ["sig"]) + kb = keybundle_from_local_file(f"file://{BASE_PATH}/jwk.json", "jwks", ["sig"]) assert len(kb) == 1 issuer = KeyIssuer() issuer.add(kb) diff --git a/tests/test_50_argument_alias.py b/tests/test_50_argument_alias.py index 8831fb14..16070101 100644 --- a/tests/test_50_argument_alias.py +++ b/tests/test_50_argument_alias.py @@ -46,7 +46,7 @@ def full_path(local_file): ] -class TestVerifyJWTKeys(object): +class TestVerifyJWTKeys: @pytest.fixture(autouse=True) def setup(self): mkey = [ @@ -104,8 +104,8 @@ def test_aud(self): assert len(keys) == 1 -PUBLIC_FILE = "{}/public_jwks.json".format(BASEDIR) -PRIVATE_FILE = "{}/private_jwks.json".format(BASEDIR) +PUBLIC_FILE = f"{BASEDIR}/public_jwks.json" +PRIVATE_FILE = f"{BASEDIR}/private_jwks.json" KEYSPEC = [ {"type": "RSA", "use": ["sig"]}, {"type": "EC", "crv": "P-256", "use": ["sig"]},