Skip to content

Commit

Permalink
feat: add get_claims method
Browse files Browse the repository at this point in the history
  • Loading branch information
grdsdev committed Feb 25, 2025
1 parent 24b9843 commit 33f00b5
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 16 deletions.
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ classifiers = [
python = "^3.9"
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
pydantic = ">=1.10,<3"
pyjwt = "^2.10.1"

[tool.poetry.dev-dependencies]
pytest = "^8.3.4"
Expand Down
49 changes: 46 additions & 3 deletions supabase_auth/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from json import loads
from time import time
from typing import Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse
from uuid import uuid4

Expand All @@ -20,11 +20,12 @@
AuthApiError,
AuthImplicitGrantRedirectError,
AuthInvalidCredentialsError,
AuthInvalidJwtError,
AuthRetryableError,
AuthSessionMissingError,
)
from ..helpers import (
decode_jwt_payload,
decode_jwt,
generate_pkce_challenge,
generate_pkce_verifier,
model_dump,
Expand All @@ -39,6 +40,8 @@
from ..http_clients import AsyncClient
from ..timer import Timer
from ..types import (
JWK,
JWKS,
AuthChangeEvent,
AuthenticatorAssuranceLevels,
AuthFlowType,
Expand Down Expand Up @@ -106,6 +109,7 @@ def __init__(
verify=verify,
proxy=proxy,
)
self._jwks: JWKS = {}
self._storage_key = storage_key or STORAGE_KEY
self._auto_refresh_token = auto_refresh_token
self._persist_session = persist_session
Expand Down Expand Up @@ -1128,7 +1132,8 @@ def _decode_jwt(self, jwt: str) -> DecodedJWTDict:
"""
Decodes a JWT (without performing any validation).
"""
return decode_jwt_payload(jwt)
decoded = decode_jwt(jwt)
return decoded["payload"]

async def exchange_code_for_session(self, params: CodeExchangeParams):
code_verifier = params.get("code_verifier") or await self._storage.get_item(
Expand All @@ -1150,3 +1155,41 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
await self._save_session(response.session)
self._notify_all_subscribers("SIGNED_IN", response.session)
return response

async def _fetch_jwks(self, kid: str, jwks: JWKS) -> JWK:
# try fetching from the suplied keys.
jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None)

if jwk:
return jwk

# try fetching from the cache.
jwk = next((jwk for jwk in self._jwks["keys"] if jwk["kid"] == kid), None)
if jwk:
return jwk

# jwk isn't cached in memory so we need to fetch it from the well-known endpoint
response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks)
if response.jwks:
self._jwks = response.jwks

# find the signing key
jwk = next(
(jwk for jwk in response.jwks["keys"] if jwk["kid"] == kid), None
)
if not jwk:
raise AuthInvalidJwtError("No matching signing key found in JWKS")

return jwk

raise AuthInvalidJwtError("JWT has no valid kid")

async def get_claims(self):
pass


def parse_jwks(response: Any) -> JWKS:
if "keys" not in response or len(response.keys) == 0:
raise AuthInvalidJwtError("JWKS is empty")

return JWKS(keys=response.keys)
8 changes: 8 additions & 0 deletions supabase_auth/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,11 @@ def to_dict(self) -> AuthApiErrorDict:
"status": self.status,
"reasons": self.reasons,
}

class AuthInvalidJwtError(CustomAuthError):
def __init__(self, message: str) -> None:
CustomAuthError.__init__(
self,
message,
"AuthInvalidJwtError",
"invalid_jwt",
56 changes: 49 additions & 7 deletions supabase_auth/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,19 @@
from base64 import urlsafe_b64decode
from datetime import datetime
from json import loads
from typing import Any, Dict, Optional, Type, TypeVar, cast
from typing import Any, Callable, Dict, Literal, Optional, Type, TypeVar, cast
from urllib.parse import urlparse

from httpx import HTTPStatusError, Response
import jwt
import jwt.algorithms
from pydantic import BaseModel

from .constants import API_VERSION_HEADER_NAME, API_VERSIONS
from .errors import (
AuthApiError,
AuthError,
AuthInvalidJwtError,
AuthRetryableError,
AuthUnknownError,
AuthWeakPasswordError,
Expand Down Expand Up @@ -192,15 +195,38 @@ def handle_exception(exception: Exception) -> AuthError:
return AuthUnknownError(get_error_message(error), e)


def decode_jwt_payload(token: str) -> Any:
parts = token.split(".")
if len(parts) != 3:
raise ValueError("JWT is not valid: not a JWT structure")
base64url = parts[1]
def str_from_base64url(base64url: str) -> str:
# Addding padding otherwise the following error happens:
# binascii.Error: Incorrect padding
base64url_with_padding = base64url + "=" * (-len(base64url) % 4)
return loads(urlsafe_b64decode(base64url_with_padding).decode("utf-8"))
return urlsafe_b64decode(base64url_with_padding).decode("utf-8")


def base64url_to_bytes(base64url: str) -> bytes:
# Addding padding otherwise the following error happens:
# binascii.Error: Incorrect padding
base64url_with_padding = base64url + "=" * (-len(base64url) % 4)
return urlsafe_b64decode(base64url_with_padding)


def decode_jwt(token: str) -> Dict[str, Any]:
parts = token.split(".")
if len(parts) != 3:
raise AuthInvalidJwtError("Invalid JWT structure")

# regex check for base64url
if not re.match(BASE64URL_REGEX, parts[1]):
raise AuthInvalidJwtError("JWT not in base64url format")

return {
"header": loads(str_from_base64url(parts[0])),
"payload": loads(str_from_base64url(parts[1])),
"signature": base64url_to_bytes(parts[2]),
"raw": {
"header": parts[0],
"payload": parts[1],
},
}


def generate_pkce_verifier(length=64):
Expand Down Expand Up @@ -267,3 +293,19 @@ def is_valid_jwt(value: str) -> bool:
return False

return True


def validate_exp(exp: int) -> None:
if not exp:
raise AuthInvalidJwtError("JWT has no expiration time")

time_now = datetime.now().timestamp()
if exp <= time_now:
raise AuthInvalidJwtError("JWT has expired")


def get_algorithm(alg: Literal["RS256", "ES256"]) -> jwt.algorithms.Algorithm:
if alg == "RS256":
return jwt.algorithms.RSAAlgorithm
elif alg == "ES256":
return jwt.algorithms.ECAlgorithm
30 changes: 30 additions & 0 deletions supabase_auth/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,36 @@ class SignOutOptions(TypedDict):
scope: NotRequired[SignOutScope]


class JWTHeader(TypedDict):
alg: Literal["RS256", "ES256", "HS256"]
typ: str
kid: str


class RequiredClaims(TypedDict):
iss: str
sub: str
auth: Union[str, List[str]]
exp: int
iat: int
role: str
aal: AuthenticatorAssuranceLevels
session_id: str


class JWTPayload(RequiredClaims, TypedDict, total=False):
pass


class JWK(TypedDict, total=False):
kty: Literal["RSA", "EC", "oct"]
key_ops: List[str]
alg: Optional[str]
kid: Optional[str]

class JWKS(TypedDict):
keys: List[JWK]

for model in [
AMREntry,
AuthResponse,
Expand Down

0 comments on commit 33f00b5

Please sign in to comment.