-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathjwk.py
220 lines (181 loc) · 7.72 KB
/
jwk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import copy
import json
import os
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric import ed448
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmp1
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_dmq1
from cryptography.hazmat.primitives.asymmetric.rsa import rsa_crt_iqmp
from ..exception import MissingValue
from ..exception import UnknownKeyType
from ..exception import UnsupportedAlgorithm
from ..exception import WrongKeyType
from ..utils import base64url_to_long
from .ec import NIST2SEC
from .ec import ECKey
from .hmac import SYMKey
from .okp import OKPKey
from .rsa import RSAKey
# JWK member names per key type. The *_REQUIRED sets are what the
# ensure_*_params validators demand; the *_PRIVATE sets are what
# key_from_jwk_dict strips when asked for a public-only key.

# EC keys.
EC_PUBLIC_REQUIRED = frozenset({"crv", "x", "y"})
EC_PUBLIC = EC_PUBLIC_REQUIRED
EC_PRIVATE_REQUIRED = frozenset({"d"})
EC_PRIVATE_OPTIONAL = frozenset()
EC_PRIVATE = EC_PRIVATE_REQUIRED.union(EC_PRIVATE_OPTIONAL)

# OKP (octet key pair) keys.
OKP_PUBLIC_REQUIRED = frozenset({"crv", "x"})
OKP_PUBLIC = OKP_PUBLIC_REQUIRED
OKP_PRIVATE_REQUIRED = frozenset({"d"})
OKP_PRIVATE_OPTIONAL = frozenset()
OKP_PRIVATE = OKP_PRIVATE_REQUIRED.union(OKP_PRIVATE_OPTIONAL)

# RSA keys. The CRT members (qi, dp, dq) are optional because
# key_from_jwk_dict can recompute them from p, q and d.
RSA_PUBLIC_REQUIRED = frozenset({"e", "n"})
RSA_PUBLIC = RSA_PUBLIC_REQUIRED
RSA_PRIVATE_REQUIRED = frozenset({"p", "q", "d"})
RSA_PRIVATE_OPTIONAL = frozenset({"qi", "dp", "dq"})
RSA_PRIVATE = RSA_PRIVATE_REQUIRED.union(RSA_PRIVATE_OPTIONAL)
def ensure_ec_params(jwk_dict, private):
    """Validate that an EC JWK dictionary carries every mandatory member.

    :param jwk_dict: Mapping holding the JWK members.
    :param private: When truthy, the private member "d" is required in
        addition to the public "crv", "x" and "y"; otherwise (including
        None) only the public members are required.
    :return: Whatever ensure_params returns (None on success).
    :raises MissingValue: raised by ensure_params when members are absent.
    """
    needed = (EC_PUBLIC_REQUIRED | EC_PRIVATE_REQUIRED) if private else EC_PUBLIC_REQUIRED
    return ensure_params("EC", frozenset(jwk_dict.keys()), needed)
def ensure_okp_params(jwk_dict, private):
    """Validate that an OKP JWK dictionary carries every mandatory member.

    :param jwk_dict: Mapping holding the JWK members.
    :param private: When truthy, the private member "d" is required in
        addition to the public "crv" and "x"; otherwise (including None)
        only the public members are required.
    :return: Whatever ensure_params returns (None on success).
    :raises MissingValue: raised by ensure_params when members are absent.
    """
    needed = (OKP_PUBLIC_REQUIRED | OKP_PRIVATE_REQUIRED) if private else OKP_PUBLIC_REQUIRED
    return ensure_params("OKP", frozenset(jwk_dict.keys()), needed)
def ensure_rsa_params(jwk_dict, private):
    """Validate that an RSA JWK dictionary carries every mandatory member.

    :param jwk_dict: Mapping holding the JWK members.
    :param private: When truthy, the private members "p", "q" and "d" are
        required in addition to the public "e" and "n"; otherwise
        (including None) only the public members are required.
    :return: Whatever ensure_params returns (None on success).
    :raises MissingValue: raised by ensure_params when members are absent.
    """
    needed = (RSA_PUBLIC_REQUIRED | RSA_PRIVATE_REQUIRED) if private else RSA_PUBLIC_REQUIRED
    return ensure_params("RSA", frozenset(jwk_dict.keys()), needed)
def ensure_params(kty, provided, required):
    """Ensure all required parameters are present in dictionary.

    :param kty: Key type name, used only in the error message.
    :param provided: Set of member names present in the JWK.
    :param required: Set of member names that must be present.
    :return: None when every required member is provided.
    :raises MissingValue: listing the absent members.
    """
    missing = required - provided
    if missing:
        # Sort so the message is stable: frozenset iteration order of strings
        # changes between interpreter runs (hash randomization).
        raise MissingValue(
            "Missing properties for kty={}, {}".format(kty, str(sorted(missing)))
        )
def key_from_jwk_dict(jwk_dict, private=None):
    """Load JWK from dictionary.

    The input is deep-copied first, so the caller's dictionary is never
    modified.

    :param jwk_dict: Dictionary representing a JWK; must contain "kty".
    :param private: For EC, RSA and OKP keys: when truthy the private-key
        members are required; when explicitly falsy (not None) any private
        members are removed so only a public key is built; when None
        (default) whatever the dictionary contains is used. Ignored for
        "oct" keys.
    :return: An ECKey, RSAKey, OKPKey or SYMKey instance selected by "kty".
    :raises MissingValue: if "kty" or another required member is absent.
    :raises UnsupportedAlgorithm: if an EC key names a curve not in NIST2SEC.
    :raises UnknownKeyType: if "kty" is not "EC", "RSA", "OKP" or "oct".
    """
    # uncouple from the original item
    _jwk_dict = copy.deepcopy(jwk_dict)

    if "kty" not in _jwk_dict:
        raise MissingValue("kty missing")

    kty = _jwk_dict["kty"]
    if kty == "EC":
        return _ec_key_from_jwk_dict(_jwk_dict, private)
    if kty == "RSA":
        return _rsa_key_from_jwk_dict(_jwk_dict, private)
    if kty == "OKP":
        ensure_okp_params(_jwk_dict, private)
        if private is not None and not private:
            # remove private components
            for v in OKP_PRIVATE:
                _jwk_dict.pop(v, None)
        return OKPKey(**_jwk_dict)
    if kty == "oct":
        if "key" not in _jwk_dict and "k" not in _jwk_dict:
            raise MissingValue('There has to be one of "k" or "key" in a symmetric key')
        return SYMKey(**_jwk_dict)
    raise UnknownKeyType("Unknown key type: {!r}".format(kty))


def _ec_key_from_jwk_dict(_jwk_dict, private):
    """Build an ECKey from an already-copied EC JWK dict (mutated in place)."""
    ensure_ec_params(_jwk_dict, private)
    if private is not None and not private:
        # remove private components
        for v in EC_PRIVATE:
            _jwk_dict.pop(v, None)

    if _jwk_dict["crv"] in NIST2SEC:
        curve = NIST2SEC[_jwk_dict["crv"]]()
    else:
        raise UnsupportedAlgorithm("Unknown curve: %s" % (_jwk_dict["crv"]))

    if _jwk_dict.get("d", None) is not None:
        # Ecdsa private key; the public key is derived from "d".
        # NOTE(review): "x"/"y" are not compared with the derived public key
        # here; presumably ECKey checks consistency -- confirm.
        _jwk_dict["priv_key"] = ec.derive_private_key(base64url_to_long(_jwk_dict["d"]), curve)
        _jwk_dict["pub_key"] = _jwk_dict["priv_key"].public_key()
    else:
        # Ecdsa public key.
        ec_pub_numbers = ec.EllipticCurvePublicNumbers(
            base64url_to_long(_jwk_dict["x"]),
            base64url_to_long(_jwk_dict["y"]),
            curve,
        )
        _jwk_dict["pub_key"] = ec_pub_numbers.public_key()
    return ECKey(**_jwk_dict)


def _rsa_key_from_jwk_dict(_jwk_dict, private):
    """Build an RSAKey from an already-copied RSA JWK dict (mutated in place)."""
    ensure_rsa_params(_jwk_dict, private)
    if private is not None and not private:
        # remove private components
        for v in RSA_PRIVATE:
            _jwk_dict.pop(v, None)

    rsa_pub_numbers = rsa.RSAPublicNumbers(
        base64url_to_long(_jwk_dict["e"]), base64url_to_long(_jwk_dict["n"])
    )

    if _jwk_dict.get("p", None) is None:
        _jwk_dict["pub_key"] = rsa_pub_numbers.public_key()
        return RSAKey(**_jwk_dict)

    # Rsa private key. p, q and d MUST be present; check them even when the
    # caller left ``private`` unset so a partial key reports MissingValue
    # rather than a bare KeyError.
    ensure_rsa_params(_jwk_dict, True)
    p_long = base64url_to_long(_jwk_dict["p"])
    q_long = base64url_to_long(_jwk_dict["q"])
    d_long = base64url_to_long(_jwk_dict["d"])

    # If not present the CRT values can be calculated from the others
    if "dp" not in _jwk_dict:
        dp_long = rsa_crt_dmp1(d_long, p_long)
    else:
        dp_long = base64url_to_long(_jwk_dict["dp"])
    if "dq" not in _jwk_dict:
        dq_long = rsa_crt_dmq1(d_long, q_long)
    else:
        dq_long = base64url_to_long(_jwk_dict["dq"])
    if "qi" not in _jwk_dict:
        qi_long = rsa_crt_iqmp(p_long, q_long)
    else:
        qi_long = base64url_to_long(_jwk_dict["qi"])

    rsa_priv_numbers = rsa.RSAPrivateNumbers(
        p_long, q_long, d_long, dp_long, dq_long, qi_long, rsa_pub_numbers
    )
    _jwk_dict["priv_key"] = rsa_priv_numbers.private_key()
    _jwk_dict["pub_key"] = _jwk_dict["priv_key"].public_key()
    return RSAKey(**_jwk_dict)
def jwk_wrap(key, use="", kid=""):
    """
    Instantiate a Key instance with the given key.

    :param key: The key to wrap: a ``cryptography`` RSA, EC, Ed25519 or
        Ed448 key object (public or private), or a ``str`` that becomes a
        symmetric key.
    :param use: What the key is expected to be used for
    :param kid: A key id; when empty, one is generated via ``add_kid()``
    :return: The Key instance
    :raises Exception: if *key* is of an unsupported type
    """
    if isinstance(key, (rsa.RSAPublicKey, rsa.RSAPrivateKey)):
        kspec = RSAKey(use=use, kid=kid).load_key(key)
    elif isinstance(key, str):
        kspec = SYMKey(key=key, use=use, kid=kid)
    elif isinstance(key, (ec.EllipticCurvePublicKey, ec.EllipticCurvePrivateKey)):
        # Private keys are accepted here just as they are for RSA above.
        kspec = ECKey(use=use, kid=kid).load_key(key)
    elif isinstance(
        key,
        (
            ed25519.Ed25519PublicKey,
            ed25519.Ed25519PrivateKey,
            ed448.Ed448PublicKey,
            ed448.Ed448PrivateKey,
        ),
    ):
        kspec = OKPKey(use=use, kid=kid).load_key(key)
    else:
        raise Exception("Unknown key type:key=" + str(type(key)))

    if not kspec.kid:
        kspec.add_kid()
    # The return value is discarded; kept for whatever state serialize()
    # sets on the instance -- NOTE(review): confirm this call is needed.
    kspec.serialize()
    return kspec
def dump_jwk(filename, key):
    """Writes a RSAKey, ECKey or SYMKey instance as a JWK to a file.

    Missing parent directories are created first; an existing file is
    overwritten.

    :param filename: Path of the file to write.
    :param key: Key instance; the JSON encoding of ``key.to_dict()`` is written.
    """
    # Serialize before opening: opening with "w" truncates the file, so a
    # failure in to_dict()/dumps() would otherwise wipe an existing key file.
    content = json.dumps(key.to_dict())

    head = os.path.dirname(filename)
    if head:
        # exist_ok avoids the race of a separate isdir() check before creating.
        os.makedirs(head, exist_ok=True)

    with open(filename, "w", encoding="utf-8") as fp:
        fp.write(content)
def import_jwk(filename):
    """Reads a JWK from a file and converts it into the appropriate key class instance.

    :param filename: Path of a file containing one JSON-encoded JWK.
    :return: The key instance produced by ``key_from_jwk_dict``, or None when
        *filename* is not an existing regular file.
    """
    if not os.path.isfile(filename):
        return None
    # JSON exchanged between systems is UTF-8 (RFC 8259); don't depend on the
    # platform's locale encoding.
    with open(filename, encoding="utf-8") as jwk_file:
        jwk_dict = json.load(jwk_file)
    return key_from_jwk_dict(jwk_dict)