Skip to content

Commit 8ab6c8e

Browse files
authored
Merge pull request #112 from bflad/bflad/authenticate-requests-httpx-retries
feat: Implement retries and caching in JWKS helpers
2 parents 57afe66 + 71d1f86 commit 8ab6c8e

File tree

2 files changed

+53
-7
lines changed

2 files changed

+53
-7
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import time
2+
3+
class Cache:
4+
""" In-memory cache with expiration. """
5+
6+
def __init__(self):
7+
self.cache = {}
8+
self.expiration_time = 300 # 5 minutes
9+
10+
def set(self, key: str | None, value: str):
11+
if key is None:
12+
return
13+
14+
self.cache[key] = (value, time.time() + self.expiration_time)
15+
16+
def get(self, key: str | None) -> str | None:
17+
if key is None:
18+
return None
19+
20+
if key in self.cache:
21+
value, expiration = self.cache[key]
22+
23+
if time.time() < expiration:
24+
return value
25+
26+
del self.cache[key]
27+
28+
return None

src/clerk_backend_api/jwks_helpers/verifytoken.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
from enum import Enum
1010
from typing import Any, Dict, List, Union, Optional, cast
1111

12+
from .cache import Cache
13+
14+
__jwkcache = Cache()
15+
1216

1317
class TokenVerificationErrorReason(Enum):
1418

@@ -110,12 +114,20 @@ def fetch_jwks(options: VerifyTokenOptions) -> Dict[str, Any]:
110114
""" Fetch JWKS from Clerk's Backend API."""
111115

112116
jwks_url = f'{options.api_url}/{options.api_version}/jwks'
113-
with httpx.Client() as client:
114-
http_res = client.get(jwks_url, headers={
115-
'Accept': 'application/json', 'Authorization': f'Bearer {options.secret_key}'
116-
})
117-
118-
if http_res.status_code != 200:
117+
transport = httpx.HTTPTransport(retries=10) # handles ConnectError and ConnectTimeout
118+
with httpx.Client(transport=transport) as client:
119+
http_res = None
120+
121+
for _ in range(10):
122+
try:
123+
http_res = client.get(jwks_url, headers={
124+
'Accept': 'application/json', 'Authorization': f'Bearer {options.secret_key}'
125+
})
126+
except httpx.TimeoutException:
127+
continue
128+
break
129+
130+
if http_res is None or http_res.status_code != 200:
119131
raise TokenVerificationError(TokenVerificationErrorReason.JWK_FAILED_TO_LOAD)
120132

121133
try:
@@ -136,6 +148,10 @@ def get_remote_jwt_key(token: str, options: VerifyTokenOptions) -> str:
136148
except jwt.InvalidTokenError as e:
137149
raise TokenVerificationError(TokenVerificationErrorReason.TOKEN_INVALID) from e
138150

151+
decoded_pem = __jwkcache.get(kid)
152+
if decoded_pem is not None:
153+
return decoded_pem
154+
139155
jwks = fetch_jwks(options).get('keys')
140156
if jwks is None:
141157
raise TokenVerificationError(TokenVerificationErrorReason.JWK_REMOTE_INVALID)
@@ -149,7 +165,9 @@ def get_remote_jwt_key(token: str, options: VerifyTokenOptions) -> str:
149165
encoding=serialization.Encoding.PEM,
150166
format=serialization.PublicFormat.SubjectPublicKeyInfo
151167
)
152-
return pem.decode('utf-8')
168+
decoded_pem = pem.decode('utf-8')
169+
__jwkcache.set(kid, decoded_pem)
170+
return decoded_pem
153171

154172
raise TokenVerificationError(TokenVerificationErrorReason.JWK_KID_MISMATCH)
155173

0 commit comments

Comments
 (0)