Skip to content

Commit efebabb

Browse files
authored
fix: Remove auth check for cloud tasks (#171)
1 parent 75d5a1b commit efebabb

File tree

5 files changed

+105
-22
lines changed

5 files changed

+105
-22
lines changed

src/firebase_functions/https_fn.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ class AuthData:
280280
The interface for Auth tokens verified in Callable functions
281281
"""
282282

283-
uid: str
283+
uid: str | None
284284
"""
285285
User ID of the ID token.
286286
"""
@@ -346,8 +346,10 @@ class CallableRequest(_typing.Generic[_core.T]):
346346
_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]
347347

348348

349-
def _on_call_handler(func: _C2, request: Request,
350-
enforce_app_check: bool) -> Response:
349+
def _on_call_handler(func: _C2,
350+
request: Request,
351+
enforce_app_check: bool,
352+
verify_token: bool = True) -> Response:
351353
try:
352354
if not _util.valid_on_call_request(request):
353355
_logging.error("Invalid request, unable to process.")
@@ -357,7 +359,8 @@ def _on_call_handler(func: _C2, request: Request,
357359
data=_json.loads(request.data)["data"],
358360
)
359361

360-
token_status = _util.on_call_check_tokens(request)
362+
token_status = _util.on_call_check_tokens(request,
363+
verify_token=verify_token)
361364

362365
if token_status.auth == _util.OnCallTokenState.INVALID:
363366
raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED,
@@ -377,8 +380,10 @@ def _on_call_handler(func: _C2, request: Request,
377380
if token_status.auth_token is not None:
378381
context = _dataclasses.replace(
379382
context,
380-
auth=AuthData(token_status.auth_token["uid"],
381-
token_status.auth_token),
383+
auth=AuthData(
384+
token_status.auth_token["uid"]
385+
if "uid" in token_status.auth_token else None,
386+
token_status.auth_token),
382387
)
383388

384389
instance_id = request.headers.get("Firebase-Instance-ID-Token")
@@ -399,7 +404,7 @@ def _on_call_handler(func: _C2, request: Request,
399404
# pylint: disable=broad-except
400405
except Exception as err:
401406
if not isinstance(err, HttpsError):
402-
_logging.error("Unhandled error", err)
407+
_logging.error("Unhandled error: %s", err)
403408
err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
404409
status = err._http_error_code.status
405410
return _make_response(_jsonify(error=err._as_dict()), status)

src/firebase_functions/private/util.py

+45-13
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
Module for internal utilities.
1616
"""
1717

18+
import base64
1819
import os as _os
1920
import json as _json
21+
import re as _re
2022
import typing as _typing
2123
import dataclasses as _dataclasses
2224
import datetime as _dt
@@ -29,6 +31,9 @@
2931
P = _typing.ParamSpec("P")
3032
R = _typing.TypeVar("R")
3133

34+
JWT_REGEX = _re.compile(
35+
r"^[a-zA-Z0-9\-_=]+?\.[a-zA-Z0-9\-_=]+?\.([a-zA-Z0-9\-_=]+)?$")
36+
3237

3338
class Sentinel:
3439
"""Internal class for RESET_VALUE."""
@@ -204,9 +209,13 @@ def as_dict(self) -> dict:
204209

205210

206211
def _on_call_check_auth_token(
207-
request: _Request
212+
request: _Request,
213+
verify_token: bool = True,
208214
) -> None | _typing.Literal[OnCallTokenState.INVALID] | dict[str, _typing.Any]:
209-
"""Validates the auth token in a callable request."""
215+
"""
216+
Validates the auth token in a callable request.
217+
If verify_token is False, the token will be decoded without verification.
218+
"""
210219
authorization = request.headers.get("Authorization")
211220
if authorization is None:
212221
return None
@@ -215,13 +224,15 @@ def _on_call_check_auth_token(
215224
return OnCallTokenState.INVALID
216225
try:
217226
id_token = authorization.replace("Bearer ", "")
218-
auth_token = _auth.verify_id_token(id_token)
227+
if verify_token:
228+
auth_token = _auth.verify_id_token(id_token)
229+
else:
230+
auth_token = _unsafe_decode_id_token(id_token)
219231
return auth_token
220232
# pylint: disable=broad-except
221233
except Exception as err:
222234
_logging.error(f"Error validating token: {err}")
223235
return OnCallTokenState.INVALID
224-
return OnCallTokenState.INVALID
225236

226237

227238
def _on_call_check_app_token(
@@ -240,23 +251,44 @@ def _on_call_check_app_token(
240251
return OnCallTokenState.INVALID
241252

242253

243-
def on_call_check_tokens(request: _Request,) -> _OnCallTokenVerification:
254+
def _unsafe_decode_id_token(token: str):
255+
# Check if the token matches the JWT pattern
256+
if not JWT_REGEX.match(token):
257+
return {}
258+
259+
# Split the token by '.' and decode each component from base64
260+
components = [base64.urlsafe_b64decode(f"{s}==") for s in token.split(".")]
261+
262+
# Attempt to parse the payload (second component) as JSON
263+
payload = components[1].decode("utf-8")
264+
try:
265+
payload = _json.loads(payload)
266+
except _json.JSONDecodeError:
267+
# If there's an error during parsing, ignore it and return the payload as is
268+
pass
269+
270+
return payload
271+
272+
273+
def on_call_check_tokens(request: _Request,
274+
verify_token: bool = True) -> _OnCallTokenVerification:
244275
"""Check tokens"""
245276
verifications = _OnCallTokenVerification()
246277

247-
auth_token = _on_call_check_auth_token(request)
278+
auth_token = _on_call_check_auth_token(request, verify_token=verify_token)
248279
if auth_token is None:
249280
verifications.auth = OnCallTokenState.MISSING
250281
elif isinstance(auth_token, dict):
251282
verifications.auth = OnCallTokenState.VALID
252283
verifications.auth_token = auth_token
253284

254-
app_token = _on_call_check_app_token(request)
255-
if app_token is None:
256-
verifications.app = OnCallTokenState.MISSING
257-
elif isinstance(app_token, dict):
258-
verifications.app = OnCallTokenState.VALID
259-
verifications.app_token = app_token
285+
if verify_token:
286+
app_token = _on_call_check_app_token(request)
287+
if app_token is None:
288+
verifications.app = OnCallTokenState.MISSING
289+
elif isinstance(app_token, dict):
290+
verifications.app = OnCallTokenState.VALID
291+
verifications.app_token = app_token
260292

261293
log_payload = {
262294
**verifications.as_dict(),
@@ -266,7 +298,7 @@ def on_call_check_tokens(request: _Request,) -> _OnCallTokenVerification:
266298
}
267299

268300
errs = []
269-
if verifications.app == OnCallTokenState.INVALID:
301+
if verify_token and verifications.app == OnCallTokenState.INVALID:
270302
errs.append(("AppCheck token was rejected.", log_payload))
271303

272304
if verifications.auth == OnCallTokenState.INVALID:

src/firebase_functions/tasks_fn.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ def on_task_dispatched_decorator(func: _C):
5353

5454
@_functools.wraps(func)
5555
def on_task_dispatched_wrapped(request: Request) -> Response:
56-
return _on_call_handler(func, request, enforce_app_check=False)
56+
return _on_call_handler(func,
57+
request,
58+
enforce_app_check=False,
59+
verify_token=False)
5760

5861
_util.set_func_endpoint_attr(
5962
on_task_dispatched_wrapped,

tests/test_tasks_fn.py

+35
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,38 @@ def example(request: CallableRequest[object]) -> str:
6868
response.get_data(as_text=True),
6969
'{"result":"Hello World"}\n',
7070
)
71+
72+
def test_token_is_decoded(self):
73+
"""
74+
Test that the token is decoded instead of verifying auth first.
75+
"""
76+
app = Flask(__name__)
77+
78+
@on_task_dispatched()
79+
def example(request: CallableRequest[object]) -> str:
80+
auth = request.auth
81+
# Make mypy happy
82+
if auth is None:
83+
self.fail("Auth is None")
84+
return "No Auth"
85+
self.assertEqual(auth.token["sub"], "firebase")
86+
self.assertEqual(auth.token["name"], "John Doe")
87+
return "Hello World"
88+
89+
with app.test_request_context("/"):
90+
# pylint: disable=line-too-long
91+
test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0"
92+
environ = EnvironBuilder(
93+
method="POST",
94+
headers={
95+
"Authorization": f"Bearer {test_token}"
96+
},
97+
json={
98+
"data": {
99+
"test": "value"
100+
},
101+
},
102+
).get_environ()
103+
request = Request(environ)
104+
response = example(request)
105+
self.assertEqual(response.status_code, 200)

tests/test_util.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
Internal utils tests.
1616
"""
1717
from os import environ, path
18-
from firebase_functions.private.util import firebase_config, microsecond_timestamp_conversion, nanoseconds_timestamp_conversion, get_precision_timestamp, normalize_path, deep_merge, PrecisionTimestamp, second_timestamp_conversion
18+
from firebase_functions.private.util import firebase_config, microsecond_timestamp_conversion, nanoseconds_timestamp_conversion, get_precision_timestamp, normalize_path, deep_merge, PrecisionTimestamp, second_timestamp_conversion, _unsafe_decode_id_token
1919
import datetime as _dt
2020

2121
test_bucket = "python-functions-testing.appspot.com"
@@ -184,3 +184,11 @@ def test_does_not_modify_originals():
184184
deep_merge(dict1, dict2)
185185
assert dict1["baz"]["answer"] == 42
186186
assert dict2["baz"]["answer"] == 33
187+
188+
189+
def test_unsafe_decode_token():
190+
# pylint: disable=line-too-long
191+
test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0"
192+
result = _unsafe_decode_id_token(test_token)
193+
assert result["sub"] == "firebase"
194+
assert result["name"] == "John Doe"

0 commit comments

Comments
 (0)