9
9
from enum import Enum
10
10
from typing import Any , Dict , List , Union , Optional , cast
11
11
12
+ from .cache import Cache
13
+
14
+ __jwkcache = Cache ()
15
+
12
16
13
17
class TokenVerificationErrorReason (Enum ):
14
18
@@ -110,12 +114,20 @@ def fetch_jwks(options: VerifyTokenOptions) -> Dict[str, Any]:
110
114
""" Fetch JWKS from Clerk's Backend API."""
111
115
112
116
jwks_url = f'{ options .api_url } /{ options .api_version } /jwks'
113
- with httpx .Client () as client :
114
- http_res = client .get (jwks_url , headers = {
115
- 'Accept' : 'application/json' , 'Authorization' : f'Bearer { options .secret_key } '
116
- })
117
-
118
- if http_res .status_code != 200 :
117
+ transport = httpx .HTTPTransport (retries = 10 ) # handles ConnectError and ConnectTimeout
118
+ with httpx .Client (transport = transport ) as client :
119
+ http_res = None
120
+
121
+ for _ in range (10 ):
122
+ try :
123
+ http_res = client .get (jwks_url , headers = {
124
+ 'Accept' : 'application/json' , 'Authorization' : f'Bearer { options .secret_key } '
125
+ })
126
+ except httpx .TimeoutException :
127
+ continue
128
+ break
129
+
130
+ if http_res is None or http_res .status_code != 200 :
119
131
raise TokenVerificationError (TokenVerificationErrorReason .JWK_FAILED_TO_LOAD )
120
132
121
133
try :
@@ -136,6 +148,10 @@ def get_remote_jwt_key(token: str, options: VerifyTokenOptions) -> str:
136
148
except jwt .InvalidTokenError as e :
137
149
raise TokenVerificationError (TokenVerificationErrorReason .TOKEN_INVALID ) from e
138
150
151
+ decoded_pem = __jwkcache .get (kid )
152
+ if decoded_pem is not None :
153
+ return decoded_pem
154
+
139
155
jwks = fetch_jwks (options ).get ('keys' )
140
156
if jwks is None :
141
157
raise TokenVerificationError (TokenVerificationErrorReason .JWK_REMOTE_INVALID )
@@ -149,7 +165,9 @@ def get_remote_jwt_key(token: str, options: VerifyTokenOptions) -> str:
149
165
encoding = serialization .Encoding .PEM ,
150
166
format = serialization .PublicFormat .SubjectPublicKeyInfo
151
167
)
152
- return pem .decode ('utf-8' )
168
+ decoded_pem = pem .decode ('utf-8' )
169
+ __jwkcache .set (kid , decoded_pem )
170
+ return decoded_pem
153
171
154
172
raise TokenVerificationError (TokenVerificationErrorReason .JWK_KID_MISMATCH )
155
173
0 commit comments