Skip to content

fix: Remove auth check for cloud tasks #171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/firebase_functions/https_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ class AuthData:
The interface for Auth tokens verified in Callable functions
"""

uid: str
uid: str | None
"""
User ID of the ID token.
"""
Expand Down Expand Up @@ -346,8 +346,10 @@ class CallableRequest(_typing.Generic[_core.T]):
_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]


def _on_call_handler(func: _C2, request: Request,
enforce_app_check: bool) -> Response:
def _on_call_handler(func: _C2,
request: Request,
enforce_app_check: bool,
verify_token: bool = True) -> Response:
try:
if not _util.valid_on_call_request(request):
_logging.error("Invalid request, unable to process.")
Expand All @@ -357,7 +359,8 @@ def _on_call_handler(func: _C2, request: Request,
data=_json.loads(request.data)["data"],
)

token_status = _util.on_call_check_tokens(request)
token_status = _util.on_call_check_tokens(request,
verify_token=verify_token)

if token_status.auth == _util.OnCallTokenState.INVALID:
raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED,
Expand All @@ -377,8 +380,10 @@ def _on_call_handler(func: _C2, request: Request,
if token_status.auth_token is not None:
context = _dataclasses.replace(
context,
auth=AuthData(token_status.auth_token["uid"],
token_status.auth_token),
auth=AuthData(
token_status.auth_token["uid"]
if "uid" in token_status.auth_token else None,
token_status.auth_token),
)

instance_id = request.headers.get("Firebase-Instance-ID-Token")
Expand All @@ -399,7 +404,7 @@ def _on_call_handler(func: _C2, request: Request,
# pylint: disable=broad-except
except Exception as err:
if not isinstance(err, HttpsError):
_logging.error("Unhandled error", err)
_logging.error("Unhandled error: %s", err)
err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
status = err._http_error_code.status
return _make_response(_jsonify(error=err._as_dict()), status)
Expand Down
58 changes: 45 additions & 13 deletions src/firebase_functions/private/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
Module for internal utilities.
"""

import base64
import os as _os
import json as _json
import re as _re
import typing as _typing
import dataclasses as _dataclasses
import datetime as _dt
Expand All @@ -29,6 +31,9 @@
P = _typing.ParamSpec("P")
R = _typing.TypeVar("R")

JWT_REGEX = _re.compile(
r"^[a-zA-Z0-9\-_=]+?\.[a-zA-Z0-9\-_=]+?\.([a-zA-Z0-9\-_=]+)?$")


class Sentinel:
"""Internal class for RESET_VALUE."""
Expand Down Expand Up @@ -204,9 +209,13 @@ def as_dict(self) -> dict:


def _on_call_check_auth_token(
request: _Request
request: _Request,
verify_token: bool = True,
) -> None | _typing.Literal[OnCallTokenState.INVALID] | dict[str, _typing.Any]:
"""Validates the auth token in a callable request."""
"""
Validates the auth token in a callable request.
If verify_token is False, the token will be decoded without verification.
"""
authorization = request.headers.get("Authorization")
if authorization is None:
return None
Expand All @@ -215,13 +224,15 @@ def _on_call_check_auth_token(
return OnCallTokenState.INVALID
try:
id_token = authorization.replace("Bearer ", "")
auth_token = _auth.verify_id_token(id_token)
if verify_token:
auth_token = _auth.verify_id_token(id_token)
else:
auth_token = _unsafe_decode_id_token(id_token)
return auth_token
# pylint: disable=broad-except
except Exception as err:
_logging.error(f"Error validating token: {err}")
return OnCallTokenState.INVALID
return OnCallTokenState.INVALID


def _on_call_check_app_token(
Expand All @@ -240,23 +251,44 @@ def _on_call_check_app_token(
return OnCallTokenState.INVALID


def on_call_check_tokens(request: _Request,) -> _OnCallTokenVerification:
def _unsafe_decode_id_token(token: str):
# Check if the token matches the JWT pattern
if not JWT_REGEX.match(token):
return {}

# Split the token by '.' and decode each component from base64
components = [base64.urlsafe_b64decode(f"{s}==") for s in token.split(".")]

# Attempt to parse the payload (second component) as JSON
payload = components[1].decode("utf-8")
try:
payload = _json.loads(payload)
except _json.JSONDecodeError:
# If there's an error during parsing, ignore it and return the payload as is
pass

return payload


def on_call_check_tokens(request: _Request,
verify_token: bool = True) -> _OnCallTokenVerification:
"""Check tokens"""
verifications = _OnCallTokenVerification()

auth_token = _on_call_check_auth_token(request)
auth_token = _on_call_check_auth_token(request, verify_token=verify_token)
if auth_token is None:
verifications.auth = OnCallTokenState.MISSING
elif isinstance(auth_token, dict):
verifications.auth = OnCallTokenState.VALID
verifications.auth_token = auth_token

app_token = _on_call_check_app_token(request)
if app_token is None:
verifications.app = OnCallTokenState.MISSING
elif isinstance(app_token, dict):
verifications.app = OnCallTokenState.VALID
verifications.app_token = app_token
if verify_token:
app_token = _on_call_check_app_token(request)
if app_token is None:
verifications.app = OnCallTokenState.MISSING
elif isinstance(app_token, dict):
verifications.app = OnCallTokenState.VALID
verifications.app_token = app_token

log_payload = {
**verifications.as_dict(),
Expand All @@ -266,7 +298,7 @@ def on_call_check_tokens(request: _Request,) -> _OnCallTokenVerification:
}

errs = []
if verifications.app == OnCallTokenState.INVALID:
if verify_token and verifications.app == OnCallTokenState.INVALID:
errs.append(("AppCheck token was rejected.", log_payload))

if verifications.auth == OnCallTokenState.INVALID:
Expand Down
5 changes: 4 additions & 1 deletion src/firebase_functions/tasks_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def on_task_dispatched_decorator(func: _C):

@_functools.wraps(func)
def on_task_dispatched_wrapped(request: Request) -> Response:
return _on_call_handler(func, request, enforce_app_check=False)
return _on_call_handler(func,
request,
enforce_app_check=False,
verify_token=False)

_util.set_func_endpoint_attr(
on_task_dispatched_wrapped,
Expand Down
35 changes: 35 additions & 0 deletions tests/test_tasks_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,38 @@ def example(request: CallableRequest[object]) -> str:
response.get_data(as_text=True),
'{"result":"Hello World"}\n',
)

def test_token_is_decoded(self):
"""
Test that the token is decoded instead of verifying auth first.
"""
app = Flask(__name__)

@on_task_dispatched()
def example(request: CallableRequest[object]) -> str:
auth = request.auth
# Make mypy happy
if auth is None:
self.fail("Auth is None")
return "No Auth"
self.assertEqual(auth.token["sub"], "firebase")
self.assertEqual(auth.token["name"], "John Doe")
return "Hello World"

with app.test_request_context("/"):
# pylint: disable=line-too-long
test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0"
environ = EnvironBuilder(
method="POST",
headers={
"Authorization": f"Bearer {test_token}"
},
json={
"data": {
"test": "value"
},
},
).get_environ()
request = Request(environ)
response = example(request)
self.assertEqual(response.status_code, 200)
10 changes: 9 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Internal utils tests.
"""
from os import environ, path
from firebase_functions.private.util import firebase_config, microsecond_timestamp_conversion, nanoseconds_timestamp_conversion, get_precision_timestamp, normalize_path, deep_merge, PrecisionTimestamp, second_timestamp_conversion
from firebase_functions.private.util import firebase_config, microsecond_timestamp_conversion, nanoseconds_timestamp_conversion, get_precision_timestamp, normalize_path, deep_merge, PrecisionTimestamp, second_timestamp_conversion, _unsafe_decode_id_token
import datetime as _dt

test_bucket = "python-functions-testing.appspot.com"
Expand Down Expand Up @@ -184,3 +184,11 @@ def test_does_not_modify_originals():
deep_merge(dict1, dict2)
assert dict1["baz"]["answer"] == 42
assert dict2["baz"]["answer"] == 33


def test_unsafe_decode_token():
# pylint: disable=line-too-long
test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0"
result = _unsafe_decode_id_token(test_token)
assert result["sub"] == "firebase"
assert result["name"] == "John Doe"