@@ -10,11 +10,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import functools
 import hashlib
 import logging
 import os
 import time
+import typing
 import urllib.parse
 
 import celery
@@ -32,6 +35,9 @@
 from warehouse.config import Environment
 from warehouse.metrics import IMetricsService
 
+if typing.TYPE_CHECKING:
+    from pyramid.request import Request
+
 # We need to trick Celery into supporting rediss:// URLs which is how redis-py
 # signals that you should use Redis with TLS.
 celery.app.backends.BACKEND_ALIASES["rediss"] = (
@@ -53,7 +59,19 @@ def _params_from_url(self, url, defaults):
 
 
 class WarehouseTask(celery.Task):
-    def __new__(cls, *args, **kwargs):
+    """
+    A custom Celery Task that integrates with Pyramid's transaction manager and
+    metrics service.
+    """
+
+    __header__: typing.Callable
+    _wh_original_run: typing.Callable
+
+    def __new__(cls, *args, **kwargs) -> WarehouseTask:
+        """
+        Override to wrap the `run` method of the task with a new method that
+        will handle exceptions from the task and retry them if they're retryable.
+        """
         obj = super().__new__(cls, *args, **kwargs)
         if getattr(obj, "__header__", None) is not None:
             obj.__header__ = functools.partial(obj.__header__, object())
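
A condensed sketch of the `run`-wrapping pattern this hunk documents; `SketchTask` and `RetryableException` are invented stand-ins, not names from this codebase. The original `run` is stashed on `_wh_original_run` and replaced with a wrapper that converts retryable errors into the task's retry signal:

```
import functools


class RetryableException(Exception):
    """Placeholder for an error type that should trigger a retry."""


class SketchTask:
    name = "sketch"

    def run(self):
        raise RetryableException("transient failure")

    def retry(self, exc=None):
        # Stand-in for celery.Task.retry(), which raises to signal a retry.
        return RuntimeError(f"retrying {self.name}: {exc}")

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)

        @functools.wraps(obj.run)
        def run(*args, **kwargs):
            try:
                return obj._wh_original_run(*args, **kwargs)
            except RetryableException as exc:
                # Convert the retryable error into the task's retry signal.
                raise obj.retry(exc=exc)

        # The same reassignment trick used in the hunk above.
        obj._wh_original_run, obj.run = obj.run, run
        return obj
```

Because the swap happens in `__new__`, every instance carries the wrapper before Celery ever invokes it.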
@@ -82,16 +100,34 @@ def run(*args, **kwargs):
                 metrics.increment("warehouse.task.failed", tags=metric_tags)
                 raise
 
-        obj._wh_original_run, obj.run = obj.run, run
+        # Reassign the `run` method to the new one we've created.
+        obj._wh_original_run, obj.run = obj.run, run  # type: ignore[method-assign]
 
         return obj
 
     def __call__(self, *args, **kwargs):
+        """
+        Override to inject a faux request object into the task when it's called.
+        There's no WSGI request object available when a task is called, so we
+        create a fake one here. This is necessary as a lot of our code assumes
+        that there's a Pyramid request object available.
+        """
         return super().__call__(*(self.get_request(),) + args, **kwargs)
 
-    def get_request(self):
+    def get_request(self) -> Request:
+        """
+        Get a request object to use for this task.
+
+        This will either return the request object that was injected into the
+        task when it was called, or it will create a new request object to use
+        for the task.
+
+        Note: The `type: ignore` comments are necessary because the `pyramid_env`
+        attribute is not defined on the request object, but we're adding it
+        dynamically.
+        """
         if not hasattr(self.request, "pyramid_env"):
-            registry = self.app.pyramid_config.registry
+            registry = self.app.pyramid_config.registry  # type: ignore[attr-defined]
             env = pyramid.scripting.prepare(registry=registry)
             env["request"].tm = transaction.TransactionManager(explicit=True)
             env["request"].timings = {"new_request_start": time.time() * 1000}
@@ -101,15 +137,29 @@ def get_request(self):
             ).hexdigest()
             self.request.update(pyramid_env=env)
 
-        return self.request.pyramid_env["request"]
+        return self.request.pyramid_env["request"]  # type: ignore[attr-defined]
 
     def after_return(self, status, retval, task_id, args, kwargs, einfo):
+        """
+        Called after the task has returned. This is where we'll clean up the
+        request object that we injected into the task.
+        """
         if hasattr(self.request, "pyramid_env"):
             pyramid_env = self.request.pyramid_env
             pyramid_env["request"]._process_finished_callbacks()
             pyramid_env["closer"]()
 
     def apply_async(self, *args, **kwargs):
+        """
+        Override the apply_async method to add an after commit hook to the
+        transaction manager, so the task is sent only after the transaction
+        has been committed.
+
+        This matters because we only want to send the task if the transaction
+        was successful; sending immediately could dispatch work based on
+        database state that is later rolled back, so the send is deferred
+        until the commit has succeeded.
+        """
         # The API design of Celery makes this threadlocal pretty impossible to
         # avoid :(
         request = get_current_request()
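
The deferral described in that docstring rests on the `transaction` package's `addAfterCommitHook` API: the hook's first argument is the commit status, so the send can be suppressed when the commit fails. A minimal sketch of the mechanism, independent of Celery:

```
import transaction


def _after_commit_hook(success, *args, **kwargs):
    # `success` is supplied by the transaction machinery and is True only
    # if the transaction committed cleanly.
    if success:
        print("commit succeeded; safe to send the task now")


tm = transaction.TransactionManager(explicit=True)
with tm:
    # Registered on the *current* transaction; the hook fires exactly once,
    # after the commit (or abort) completes.
    tm.get().addAfterCommitHook(_after_commit_hook, args=(), kws={})
# Exiting the block commits, so the hook fires with success=True.
```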
@@ -137,17 +187,51 @@ def apply_async(self, *args, **kwargs):
         )
 
     def retry(self, *args, **kwargs):
+        """
+        Override the retry method to increment a metric when a task is retried.
+
+        This runs every time a task is retried, so incrementing a metric here
+        lets us track how often each task is retried.
+        """
         request = get_current_request()
         metrics = request.find_service(IMetricsService, context=None)
         metrics.increment("warehouse.task.retried", tags=[f"task:{self.name}"])
         return super().retry(*args, **kwargs)
 
     def _after_commit_hook(self, success, *args, **kwargs):
+        """
+        This is the hook that gets called after the transaction has been
+        committed. We'll only send the task if the transaction was successful.
+        """
         if success:
             super().apply_async(*args, **kwargs)
 
 
 def task(**kwargs):
+    """
+    A decorator that can be used to define a Celery task.
+
+    A thin wrapper around Celery's `task` decorator that allows us to attach
+    the task to the Celery app when the configuration is scanned during
+    application startup.
+
+    This decorator also sets the `shared` option to `False` by default. This
+    means that the task will be created anew for each worker process that is
+    started. This is important because the `WarehouseTask` class that we use
+    for our tasks is not thread-safe, so we need to ensure that each worker
+    process has its own instance of the task.
+
+    This decorator also adds the task to the `warehouse` category in the
+    configuration scanner. This is important because we use this category to
+    find all the tasks that have been defined in the configuration.
+
+    Example usage:
+    ```
+    @tasks.task(...)
+    def my_task(self, *args, **kwargs):
+        pass
+    ```
+    """
     kwargs.setdefault("shared", False)
 
     def deco(wrapped):
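
To make the docstring's example concrete, here is a hypothetical task in the shape Warehouse tasks typically take; `send_welcome_email` is invented. With `bind=True` Celery passes the task object first, and the faux request injected by `WarehouseTask.__call__` arrives next:

```
# Hypothetical example, not a real Warehouse task.
@task(bind=True, ignore_result=True, acks_late=True)
def send_welcome_email(task, request, user_id):
    # `task` is the bound WarehouseTask instance; `request` is the faux
    # Pyramid request created by get_request() above.
    ...
```

Call sites then typically dispatch it as `request.task(send_welcome_email).delay(user_id)`, via the request method this module wires up (again, names here are illustrative).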
@@ -193,7 +277,7 @@ def add_task():
 def includeme(config):
     s = config.registry.settings
 
-    broker_transport_options = {}
+    broker_transport_options: dict[str, str | dict] = {}
 
     broker_url = s.get("celery.broker_url")
     if broker_url is None:
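
The new annotation reflects that broker transport options mix flat string values with nested mappings. A hypothetical shape, borrowing option names from kombu's SQS transport for illustration:

```
# Illustrative only; option names are from kombu's SQS transport, and the
# values here are made up.
broker_transport_options: dict[str, str | dict] = {
    "region": "us-east-1",              # a flat string option
    "queue_name_prefix": "warehouse-",  # another flat option
    "predefined_queues": {              # a nested mapping option
        "default": {"url": "https://sqs.us-east-1.amazonaws.com/1234/default"},
    },
}
```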