Skip to content

Commit e19773b

Browse files
committed
[celery] Drop celery.RequestIDAwareTask in favor of signals
1 parent 780d34c commit e19773b

File tree

4 files changed

+33
-72
lines changed

4 files changed

+33
-72
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,14 @@ In order to use this feature you need to enable the celery plugin and configure
9090
use `current_request_id()` from inside your worker
9191

9292
```python
93-
from flask_log_request_id.extras.celery import RequestIDAwareTask
93+
from flask_log_request_id.extras.celery import enable_request_id_propagation
9494
from flask_log_request_id import current_request_id
9595
from celery.app import Celery
9696
import logging
9797

98-
celery = Celery(task_cls=RequestIDAwareTask)
98+
celery = Celery()
99+
enable_request_id_propagation(celery) # << This step here is critical to propagate request-id to workers
100+
99101
app = Flask()
100102

101103
@celery.task()
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from celery import Celery, signals
2-
from flask_log_request_id.extras.celery import add_request_id_header, RequestIDAwareTask
2+
from flask_log_request_id.extras.celery import enable_request_id_propagation
33

4-
# Either sublcass all tasks from RequestIDAwareTask
5-
celery = Celery(task_cls=RequestIDAwareTask)
4+
celery = Celery()
5+
6+
# You need to enable propagation on celery application
7+
enable_request_id_propagation(celery)
68

7-
# Or alternatively register the provided signal handler
8-
signals.before_task_publish.connect(add_request_id_header)

flask_log_request_id/extras/celery.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from celery import Task, current_task
1+
from celery import current_task, signals
22
import logging as _logging
33

44
from ..request_id import current_request_id
@@ -9,24 +9,26 @@
99
logger = _logging.getLogger(__name__)
1010

1111

12-
class RequestIDAwareTask(Task):
12+
def enable_request_id_propagation(celery_app):
1313
"""
14-
Task base class that injects request id to task request object with key 'x_request_id'.
14+
Will attach signal on celery application in order to propagate
15+
current request id to workers
16+
:param celery_app: The celery application
1517
"""
18+
signals.before_task_publish.connect(on_before_publish_insert_request_id_header)
1619

17-
def apply_async(self, *args, **kwargs):
18-
# Set default value for 'headers' argument
19-
if 'headers' not in kwargs or kwargs['headers'] is None:
20-
kwargs['headers'] = {}
20+
21+
def on_before_publish_insert_request_id_header(headers, **kwargs):
22+
"""
23+
This function is meant to be used as signal processor for "before_task_publish".
24+
:param Dict headers: The headers of the message
25+
:param kwargs: Any extra keyword arguments
26+
"""
27+
if _CELERY_X_HEADER not in headers:
2128
request_id = current_request_id()
29+
headers[_CELERY_X_HEADER] = request_id
2230
logger.debug("Forwarding request_id '{}' to the task consumer.".format(request_id))
23-
kwargs['headers'][_CELERY_X_HEADER] = request_id
24-
25-
return super(RequestIDAwareTask, self).apply_async(*args, **kwargs)
2631

27-
def add_request_id_header(headers=None, **kwargs):
28-
if _CELERY_X_HEADER not in headers:
29-
headers[_CELERY_X_HEADER] = current_request_id()
3032

3133
def ctx_celery_task_get_request_id():
3234
"""

tests/extras/celery_tests.py

Lines changed: 9 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33

44
from celery import Celery
55
from flask_log_request_id.extras.celery import (ExecutedOutsideContext,
6-
RequestIDAwareTask,
7-
add_request_id_header,
6+
on_before_publish_insert_request_id_header,
87
ctx_celery_task_get_request_id)
98

109

@@ -23,69 +22,27 @@ def apply_async(self, *args, **kwargs):
2322
class CeleryIntegrationTestCase(unittest.TestCase):
2423

2524
@mock.patch('flask_log_request_id.extras.celery.current_request_id')
26-
def test_mixin_injection(self, mocked_current_request_id):
27-
28-
patcher = mock.patch.object(RequestIDAwareTask, '__bases__', (MockedTask,))
29-
30-
with patcher:
31-
patcher.is_local = True
32-
33-
mocked_current_request_id.return_value = 15
34-
task = RequestIDAwareTask()
35-
task.apply_async('test', foo='bar')
36-
self.assertEqual(
37-
task.apply_async_called['args'],
38-
('test', ))
39-
40-
self.assertDictEqual(
41-
task.apply_async_called['kwargs'], {
42-
'headers': {'x_request_id': 15},
43-
'foo': 'bar'
44-
})
45-
46-
@mock.patch('flask_log_request_id.extras.celery.current_request_id')
47-
def test_issue21_called_with_headers_None(self, mocked_current_request_id):
48-
49-
patcher = mock.patch.object(RequestIDAwareTask, '__bases__', (MockedTask,))
50-
51-
with patcher:
52-
patcher.is_local = True
53-
54-
mocked_current_request_id.return_value = 15
55-
task = RequestIDAwareTask()
56-
task.apply_async('test', foo='bar', headers=None)
57-
self.assertEqual(
58-
task.apply_async_called['args'],
59-
('test', ))
60-
61-
self.assertDictEqual(
62-
task.apply_async_called['kwargs'], {
63-
'headers': {'x_request_id': 15},
64-
'foo': 'bar'
65-
})
66-
67-
@mock.patch('flask_log_request_id.extras.celery.current_request_id')
68-
def test_before_task_publish_hooks_adds_header(self, mocked_current_request_id):
25+
def test_enable_request_id_propagation(self, mocked_current_request_id):
6926
mocked_current_request_id.return_value = 15
7027

7128
headers = {}
72-
add_request_id_header(headers={})
73-
print(headers)
74-
self.assertDictEqual(headers, {
75-
'x_request_id': 15
76-
})
29+
on_before_publish_insert_request_id_header(headers=headers)
30+
self.assertDictEqual(
31+
{
32+
'x_request_id': 15
33+
},
34+
headers)
7735

7836
@mock.patch('flask_log_request_id.extras.celery.current_task')
7937
def test_ctx_fetcher_outside_context(self, mocked_current_task):
80-
8138
mocked_current_task._get_current_object.return_value = None
8239
with self.assertRaises(ExecutedOutsideContext):
8340
ctx_celery_task_get_request_id()
8441

8542
@mock.patch('flask_log_request_id.extras.celery.current_task')
8643
def test_ctx_fetcher_inside_context(self, mocked_current_task):
8744
mocked_current_task._get_current_object.return_value = True
88-
mocked_current_task.request.get.side_effect = lambda a, default: {'x_request_id': 15, 'other':'bar'}[a]
45+
mocked_current_task.request.get.side_effect = lambda a, default: {'x_request_id': 15, 'other': 'bar'}[a]
8946

9047
self.assertEqual(ctx_celery_task_get_request_id(), 15)
9148

0 commit comments

Comments
 (0)