Skip to content

Commit ec10639

Browse files
authored
feat: split tasks https (#201)
1 parent c1c801c commit ec10639

File tree

4 files changed

+60
-66
lines changed

4 files changed

+60
-66
lines changed

Diff for: src/firebase_functions/https_fn.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,8 @@ class CallableRequest(_typing.Generic[_core.T]):
352352
_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]
353353

354354

355-
def _on_call_handler(func: _C2,
356-
request: Request,
357-
enforce_app_check: bool,
358-
verify_token: bool = True) -> Response:
355+
def _on_call_handler(func: _C2, request: Request,
356+
enforce_app_check: bool) -> Response:
359357
try:
360358
if not _util.valid_on_call_request(request):
361359
_logging.error("Invalid request, unable to process.")
@@ -365,8 +363,7 @@ def _on_call_handler(func: _C2,
365363
data=_json.loads(request.data)["data"],
366364
)
367365

368-
token_status = _util.on_call_check_tokens(request,
369-
verify_token=verify_token)
366+
token_status = _util.on_call_check_tokens(request)
370367

371368
if token_status.auth == _util.OnCallTokenState.INVALID:
372369
raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED,
@@ -420,7 +417,7 @@ def _on_call_handler(func: _C2,
420417
def on_request(**kwargs) -> _typing.Callable[[_C1], _C1]:
421418
"""
422419
Handler which handles HTTPS requests.
423-
Requires a function that takes a ``Request`` and ``Response`` object,
420+
Requires a function that takes a ``Request`` and ``Response`` object,
424421
the same signature as a Flask app.
425422
426423
Example:

Diff for: src/firebase_functions/private/util.py

+12-18
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,10 @@ def as_dict(self) -> dict:
212212

213213

214214
def _on_call_check_auth_token(
215-
request: _Request,
216-
verify_token: bool = True,
215+
request: _Request
217216
) -> None | _typing.Literal[OnCallTokenState.INVALID] | dict[str, _typing.Any]:
218217
"""
219-
Validates the auth token in a callable request.
218+
Validates the auth token in a callable request.
220219
If verify_token is False, the token will be decoded without verification.
221220
"""
222221
authorization = request.headers.get("Authorization")
@@ -227,10 +226,7 @@ def _on_call_check_auth_token(
227226
return OnCallTokenState.INVALID
228227
try:
229228
id_token = authorization.replace("Bearer ", "")
230-
if verify_token:
231-
auth_token = _auth.verify_id_token(id_token)
232-
else:
233-
auth_token = _unsafe_decode_id_token(id_token)
229+
auth_token = _auth.verify_id_token(id_token)
234230
return auth_token
235231
# pylint: disable=broad-except
236232
except Exception as err:
@@ -273,25 +269,23 @@ def _unsafe_decode_id_token(token: str):
273269
return payload
274270

275271

276-
def on_call_check_tokens(request: _Request,
277-
verify_token: bool = True) -> _OnCallTokenVerification:
272+
def on_call_check_tokens(request: _Request) -> _OnCallTokenVerification:
278273
"""Check tokens"""
279274
verifications = _OnCallTokenVerification()
280275

281-
auth_token = _on_call_check_auth_token(request, verify_token=verify_token)
276+
auth_token = _on_call_check_auth_token(request)
282277
if auth_token is None:
283278
verifications.auth = OnCallTokenState.MISSING
284279
elif isinstance(auth_token, dict):
285280
verifications.auth = OnCallTokenState.VALID
286281
verifications.auth_token = auth_token
287282

288-
if verify_token:
289-
app_token = _on_call_check_app_token(request)
290-
if app_token is None:
291-
verifications.app = OnCallTokenState.MISSING
292-
elif isinstance(app_token, dict):
293-
verifications.app = OnCallTokenState.VALID
294-
verifications.app_token = app_token
283+
app_token = _on_call_check_app_token(request)
284+
if app_token is None:
285+
verifications.app = OnCallTokenState.MISSING
286+
elif isinstance(app_token, dict):
287+
verifications.app = OnCallTokenState.VALID
288+
verifications.app_token = app_token
295289

296290
log_payload = {
297291
**verifications.as_dict(),
@@ -301,7 +295,7 @@ def on_call_check_tokens(request: _Request,
301295
}
302296

303297
errs = []
304-
if verify_token and verifications.app == OnCallTokenState.INVALID:
298+
if verifications.app == OnCallTokenState.INVALID:
305299
errs.append(("AppCheck token was rejected.", log_payload))
306300

307301
if verifications.auth == OnCallTokenState.INVALID:

Diff for: src/firebase_functions/tasks_fn.py

+44-6
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,55 @@
1616
# pylint: disable=protected-access
1717
import typing as _typing
1818
import functools as _functools
19+
import dataclasses as _dataclasses
20+
import json as _json
1921

20-
from flask import Request, Response
22+
from flask import Request, Response, make_response as _make_response, jsonify as _jsonify
2123

24+
import firebase_functions.core as _core
2225
import firebase_functions.options as _options
2326
import firebase_functions.private.util as _util
24-
from firebase_functions.https_fn import CallableRequest, _on_call_handler
27+
from firebase_functions.https_fn import CallableRequest, HttpsError, FunctionsErrorCode
28+
29+
from functions_framework import logging as _logging
2530

2631
_C = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]
32+
_C1 = _typing.Callable[[Request], Response]
33+
_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]
34+
35+
36+
def _on_call_handler(func: _C2, request: Request) -> Response:
37+
try:
38+
if not _util.valid_on_call_request(request):
39+
_logging.error("Invalid request, unable to process.")
40+
raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request")
41+
context: CallableRequest = CallableRequest(
42+
raw_request=request,
43+
data=_json.loads(request.data)["data"],
44+
)
45+
46+
instance_id = request.headers.get("Firebase-Instance-ID-Token")
47+
if instance_id is not None:
48+
# Validating the token requires an http request, so we don't do it.
49+
# If the user wants to use it for something, it will be validated then.
50+
# Currently, the only real use case for this token is for sending
51+
# pushes with FCM. In that case, the FCM APIs will validate the token.
52+
context = _dataclasses.replace(
53+
context,
54+
instance_id_token=request.headers.get(
55+
"Firebase-Instance-ID-Token"),
56+
)
57+
result = _core._with_init(func)(context)
58+
return _jsonify(result=result)
59+
# Disable broad exceptions lint since we want to handle all exceptions here
60+
# and wrap as an HttpsError.
61+
# pylint: disable=broad-except
62+
except Exception as err:
63+
if not isinstance(err, HttpsError):
64+
_logging.error("Unhandled error: %s", err)
65+
err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
66+
status = err._http_error_code.status
67+
return _make_response(_jsonify(error=err._as_dict()), status)
2768

2869

2970
@_util.copy_func_kwargs(_options.TaskQueueOptions)
@@ -53,10 +94,7 @@ def on_task_dispatched_decorator(func: _C):
5394

5495
@_functools.wraps(func)
5596
def on_task_dispatched_wrapped(request: Request) -> Response:
56-
return _on_call_handler(func,
57-
request,
58-
enforce_app_check=False,
59-
verify_token=False)
97+
return _on_call_handler(func, request)
6098

6199
_util.set_func_endpoint_attr(
62100
on_task_dispatched_wrapped,

Diff for: tests/test_tasks_fn.py

-35
Original file line numberDiff line numberDiff line change
@@ -71,41 +71,6 @@ def example(request: CallableRequest[object]) -> str:
7171
'{"result":"Hello World"}\n',
7272
)
7373

74-
def test_token_is_decoded(self):
75-
"""
76-
Test that the token is decoded instead of verifying auth first.
77-
"""
78-
app = Flask(__name__)
79-
80-
@on_task_dispatched()
81-
def example(request: CallableRequest[object]) -> str:
82-
auth = request.auth
83-
# Make mypy happy
84-
if auth is None:
85-
self.fail("Auth is None")
86-
return "No Auth"
87-
self.assertEqual(auth.token["sub"], "firebase")
88-
self.assertEqual(auth.token["name"], "John Doe")
89-
return "Hello World"
90-
91-
with app.test_request_context("/"):
92-
# pylint: disable=line-too-long
93-
test_token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJmaXJlYmFzZSIsIm5hbWUiOiJKb2huIERvZSJ9.74A24Y821E7CZx8aYCsCKo0Y-W0qXwqME-14QlEMcB0"
94-
environ = EnvironBuilder(
95-
method="POST",
96-
headers={
97-
"Authorization": f"Bearer {test_token}"
98-
},
99-
json={
100-
"data": {
101-
"test": "value"
102-
},
103-
},
104-
).get_environ()
105-
request = Request(environ)
106-
response = example(request)
107-
self.assertEqual(response.status_code, 200)
108-
10974
def test_calls_init(self):
11075
hello = None
11176

0 commit comments

Comments
 (0)