diff --git a/poetry.lock b/poetry.lock index 3c6c549e..71582b30 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1003,13 +1003,13 @@ urllib3 = ">=1.26.0" [[package]] name = "pyjwt" -version = "2.9.0" +version = "2.10.1" description = "JSON Web Token implementation in Python" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"}, - {file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"}, + {file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"}, + {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, ] [package.dependencies] @@ -1444,4 +1444,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4330f09b4fdf6c3c2ae616374a7e0d5d644d4f2338b4e20697f30727fb90acc5" +content-hash = "cf94465a96161855010f04eff511b178ff2d1ba53398bf1674a15b2ca7c827da" diff --git a/pyproject.toml b/pyproject.toml index 1ba68276..e75a5eb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ python = "^3.9" httpx = {version = ">=0.26,<0.29", extras = ["http2"]} pydantic = ">=1.10,<3" +pyjwt = "^2.10.1" [tool.poetry.dev-dependencies] pytest = "^8.3.4" diff --git a/supabase_auth/_async/gotrue_client.py b/supabase_auth/_async/gotrue_client.py index 2a2a5564..4069991b 100644 --- a/supabase_auth/_async/gotrue_client.py +++ b/supabase_auth/_async/gotrue_client.py @@ -4,7 +4,7 @@ from functools import partial from json import loads from time import time -from typing import Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 @@ -20,11 +20,12 @@ AuthApiError, AuthImplicitGrantRedirectError, AuthInvalidCredentialsError, + AuthInvalidJwtError, AuthRetryableError, AuthSessionMissingError, ) from ..helpers import ( - decode_jwt_payload, + decode_jwt, generate_pkce_challenge, generate_pkce_verifier, model_dump, @@ -39,6 +40,8 @@ from ..http_clients import AsyncClient from ..timer import Timer from ..types import ( + JWK, + JWKS, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -106,6 +109,7 @@ def __init__( verify=verify, proxy=proxy, ) + self._jwks: JWKS = {} self._storage_key = storage_key or STORAGE_KEY self._auto_refresh_token = auto_refresh_token self._persist_session = persist_session @@ -1128,7 +1132,8 @@ def _decode_jwt(self, jwt: str) -> DecodedJWTDict: """ Decodes a JWT (without performing any validation). """ - return decode_jwt_payload(jwt) + decoded = decode_jwt(jwt) + return decoded["payload"] async def exchange_code_for_session(self, params: CodeExchangeParams): code_verifier = params.get("code_verifier") or await self._storage.get_item( @@ -1150,3 +1155,41 @@ async def exchange_code_for_session(self, params: CodeExchangeParams): await self._save_session(response.session) self._notify_all_subscribers("SIGNED_IN", response.session) return response + + async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK: + # try fetching from the suplied keys. + jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) + + if jwk: + return jwk + + # try fetching from the cache. + jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None) + if jwk: + return jwk + + # jwk isn't cached in memory so we need to fetch it from the well-known endpoint + response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks) + if response.jwks: + self._jwks = response.jwks + + # find the signing key + jwk = next( + (jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None + ) + if not jwk: + raise AuthInvalidJwtError("No matching signing key found in JWKS") + + return jwk + + raise AuthInvalidJwtError("JWT has no valid kid") + + async def get_claims(self): + pass + + +def parse_jwks(response: Any) -> JWKS: + if "keys" not in response or len(response.keys) == 0: + raise AuthInvalidJwtError("JWKS is empty") + + return JWKS(keys=response.keys) diff --git a/supabase_auth/errors.py b/supabase_auth/errors.py index fa693894..1cf8de3c 100644 --- a/supabase_auth/errors.py +++ b/supabase_auth/errors.py @@ -225,3 +225,11 @@ def to_dict(self) -> AuthApiErrorDict: "status": self.status, "reasons": self.reasons, } + +class AuthInvalidJwtError(CustomAuthError): + def __init__(self, message: str) -> None: + CustomAuthError.__init__( + self, + message, + "AuthInvalidJwtError", + "invalid_jwt", \ No newline at end of file diff --git a/supabase_auth/helpers.py b/supabase_auth/helpers.py index ae9dc7c5..8be212e2 100644 --- a/supabase_auth/helpers.py +++ b/supabase_auth/helpers.py @@ -8,16 +8,19 @@ from base64 import urlsafe_b64decode from datetime import datetime from json import loads -from typing import Any, Dict, Optional, Type, TypeVar, cast +from typing import Any, Callable, Dict, Literal, Optional, Type, TypeVar, cast from urllib.parse import urlparse from httpx import HTTPStatusError, Response +import jwt +import jwt.algorithms from pydantic import BaseModel from .constants import API_VERSION_HEADER_NAME, API_VERSIONS from .errors import ( AuthApiError, AuthError, + AuthInvalidJwtError, AuthRetryableError, AuthUnknownError, AuthWeakPasswordError, @@ -192,15 +195,38 @@ def handle_exception(exception: Exception) -> AuthError: return AuthUnknownError(get_error_message(error), e) -def decode_jwt_payload(token: str) -> Any: - parts = token.split(".") - if len(parts) != 3: - raise ValueError("JWT is not valid: not a JWT structure") - base64url = parts[1] +def str_from_base64url(base64url: str) -> str: # Addding padding otherwise the following error happens: # binascii.Error: Incorrect padding base64url_with_padding = base64url + "=" * (-len(base64url) % 4) - return loads(urlsafe_b64decode(base64url_with_padding).decode("utf-8")) + return urlsafe_b64decode(base64url_with_padding).decode("utf-8") + + +def base64url_to_bytes(base64url: str) -> bytes: + # Addding padding otherwise the following error happens: + # binascii.Error: Incorrect padding + base64url_with_padding = base64url + "=" * (-len(base64url) % 4) + return urlsafe_b64decode(base64url_with_padding) + + +def decode_jwt(token: str) -> Dict[str, Any]: + parts = token.split(".") + if len(parts) != 3: + raise AuthInvalidJwtError("Invalid JWT structure") + + # regex check for base64url + if not re.match(BASE64URL_REGEX, parts[1]): + raise AuthInvalidJwtError("JWT not in base64url format") + + return { + "header": loads(str_from_base64url(parts[0])), + "payload": loads(str_from_base64url(parts[1])), + "signature": base64url_to_bytes(parts[2]), + "raw": { + "header": parts[0], + "payload": parts[1], + }, + } def generate_pkce_verifier(length=64): @@ -267,3 +293,19 @@ def is_valid_jwt(value: str) -> bool: return False return True + + +def validate_exp(exp: int) -> None: + if not exp: + raise AuthInvalidJwtError("JWT has no expiration time") + + time_now = datetime.now().timestamp() + if exp <= time_now: + raise AuthInvalidJwtError("JWT has expired") + + +def get_algorithm(alg: Literal["RS256", "ES256"]) -> jwt.algorithms.Algorithm: + if alg == "RS256": + return jwt.algorithms.RSAAlgorithm + elif alg == "ES256": + return jwt.algorithms.ECAlgorithm diff --git a/supabase_auth/types.py b/supabase_auth/types.py index 86bda3e2..dd499aa0 100644 --- a/supabase_auth/types.py +++ b/supabase_auth/types.py @@ -789,6 +789,36 @@ class SignOutOptions(TypedDict): scope: NotRequired[SignOutScope] +class JWTHeader(TypedDict): + alg: Literal["RS256", "ES256", "HS256"] + typ: str + kid: str + + +class RequiredClaims(TypedDict): + iss: str + sub: str + auth: Union[str, List[str]] + exp: int + iat: int + role: str + aal: AuthenticatorAssuranceLevels + session_id: str + + +class JWTPayload(RequiredClaims, TypedDict, total=False): + pass + + +class JWK(TypedDict, total=False): + kty: Literal["RSA", "EC", "oct"] + key_ops: List[str] + alg: Optional[str] + kid: Optional[str] + +class JWKS(TypedDict): + keys: List[JWK] + for model in [ AMREntry, AuthResponse,