Skip to content

Commit b10fa7a

Browse files
lpsingerjak574
andauthored
Decode access token in a type-safe, low boilerplate way (nasa-gcn#1906)
* Decode access token in a type-safe, low boilerplate way - Instead of inheriting from HTTPBearer, use it as a dependency. - This allows us to return something other than the credential, namely: the verified and decoded claims. - Factor out the issuer and JWKS discovery as dependency functions that have no other dependencies so that these can be shared across multiple requests. * Remove unused import of TypedDict --------- Co-authored-by: Jamie Kennea <[email protected]>
1 parent 7d5a1e8 commit b10fa7a

File tree

1 file changed

+61
-67
lines changed

1 file changed

+61
-67
lines changed

python/across_api/auth/api.py

Lines changed: 61 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -4,86 +4,80 @@
44

55

66
import os
7-
from typing import Optional
7+
from typing import Annotated, Any
88

99
from jose import jwt
1010
from jose.exceptions import JWTError
1111
import httpx # type: ignore
12-
from fastapi import Depends, HTTPException, Request
12+
from fastapi import Depends, HTTPException
1313
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
1414

1515
from ..base.api import app
1616
from .schema import VerifyAuth
1717

1818

19-
class JWTBearer(HTTPBearer):
20-
def __init__(self, **kwargs):
21-
# Configure URL from IdP
22-
user_pool_id = os.environ.get("COGNITO_USER_POOL_ID")
23-
if user_pool_id is not None:
24-
self.provider_url = f"https://cognito-idp.{user_pool_id.split('_')[0]}.amazonaws.com/{user_pool_id}/"
25-
elif os.environ.get("ARC_ENV") == "testing":
26-
self.provider_url = (
27-
f"http://localhost:{os.environ.get('ARC_OIDC_IDP_PORT')}/"
28-
)
29-
else:
30-
raise RuntimeError(
31-
"Environment variable COGNITO_USER_POOL_ID must be defined in production.",
32-
)
33-
34-
super().__init__(**kwargs)
35-
36-
async def __call__(self, request: Request) -> None:
37-
credentials: Optional[HTTPAuthorizationCredentials] = await super().__call__(
38-
request
39-
)
40-
"""Validate credentials if passed"""
41-
if credentials:
42-
# Fetch the well-known config from the IdP
43-
async with httpx.AsyncClient() as client:
44-
resp = await client.get(
45-
self.provider_url + ".well-known/openid-configuration"
46-
)
47-
resp.raise_for_status()
48-
49-
# Find the jwks_uri and token algorithms from the well-known config
50-
well_known = resp.json()
51-
jwks_uri = well_known["jwks_uri"]
52-
token_alg = well_known["id_token_signing_alg_values_supported"]
53-
54-
# Fetch signing key from Cognito
55-
async with httpx.AsyncClient() as client:
56-
resp = await client.get(jwks_uri)
57-
jwks_data = resp.json()
58-
header = jwt.get_unverified_header(credentials.credentials)
59-
for signing_key in jwks_data["keys"]:
60-
if signing_key["kid"] == header["kid"]:
61-
break
62-
else:
63-
raise HTTPException(
64-
status_code=401, detail="Authentication error: Invalid key."
65-
)
66-
67-
# Validate the credentials
68-
try:
69-
jwt.decode(
70-
credentials.credentials,
71-
key=signing_key,
72-
algorithms=token_alg,
73-
)
74-
except JWTError as e:
75-
raise HTTPException(
76-
status_code=401, detail=f"Authentication error: {e}"
77-
)
78-
else:
79-
raise AssertionError("No credentials passed.")
80-
81-
82-
security = JWTBearer(
19+
bearer = HTTPBearer(
8320
scheme_name="ACROSS API Authorization",
8421
description="Enter your access token.",
22+
auto_error=True,
8523
)
86-
JWTBearerDep = [Depends(security)]
24+
25+
26+
async def issuer() -> dict[str, Any]:
27+
"""Discover OpenID Connect issuer configuration."""
28+
# Configure URL from IdP
29+
user_pool_id = os.environ.get("COGNITO_USER_POOL_ID")
30+
if user_pool_id is not None:
31+
provider_url = f"https://cognito-idp.{user_pool_id.split('_')[0]}.amazonaws.com/{user_pool_id}/"
32+
elif os.environ.get("ARC_ENV") == "testing":
33+
provider_url = f"http://localhost:{os.environ.get('ARC_OIDC_IDP_PORT')}/"
34+
else:
35+
raise RuntimeError(
36+
"Environment variable COGNITO_USER_POOL_ID must be defined in production.",
37+
)
38+
39+
# Fetch the well-known config from the IdP
40+
async with httpx.AsyncClient() as client:
41+
resp = await client.get(f"{provider_url}.well-known/openid-configuration")
42+
resp.raise_for_status()
43+
return resp.json()
44+
45+
46+
async def jwks(issuer: Annotated[dict[str, Any], Depends(issuer)]) -> dict[str, Any]:
47+
"""Fetch JSON Web Key signature set from the OpenID Connect issuer."""
48+
async with httpx.AsyncClient() as client:
49+
resp = await client.get(issuer["jwks_uri"])
50+
resp.raise_for_status()
51+
return resp.json()
52+
53+
54+
async def claims(
55+
credentials: Annotated[HTTPAuthorizationCredentials, Depends(bearer)],
56+
issuer: Annotated[dict[str, Any], Depends(issuer)],
57+
jwks: Annotated[dict[str, Any], Depends(jwks)],
58+
) -> dict[str, Any]:
59+
"""Verify and return the claims in the request's access token."""
60+
header = jwt.get_unverified_header(credentials.credentials)
61+
for signing_key in jwks["keys"]:
62+
if signing_key["kid"] == header["kid"]:
63+
break
64+
else:
65+
raise HTTPException(
66+
status_code=401, detail="Authentication error: Invalid key."
67+
)
68+
69+
# Validate the credentials
70+
try:
71+
return jwt.decode(
72+
credentials.credentials,
73+
key=signing_key,
74+
algorithms=issuer["id_token_signing_alg_values_supported"],
75+
)
76+
except JWTError as e:
77+
raise HTTPException(status_code=401, detail=f"Authentication error: {e}")
78+
79+
80+
JWTBearerDep = [Depends(claims)]
8781

8882

8983
@app.get("/auth/verify", dependencies=JWTBearerDep)

0 commit comments

Comments
 (0)