
Commit e80ac40

lint/doc updates (#17310)
1 parent 775dc3b commit e80ac40

File tree (5 files changed: +124 -28 lines changed)

warehouse/cli/db/dbml.py
warehouse/packaging/models.py
warehouse/search/tasks.py
warehouse/sitemap/views.py
warehouse/tasks.py

warehouse/cli/db/dbml.py (+2 -2)

@@ -12,7 +12,7 @@
 
 import json
 
-from collections.abc import Iterator
+from collections.abc import Iterable
 from typing import Literal, NotRequired, TypedDict
 
 import click
@@ -95,7 +95,7 @@ class TableInfo(TypedDict):
     comment: NotRequired[str]
 
 
-def generate_dbml_file(tables: Iterator[Table], _output: str | None) -> None:
+def generate_dbml_file(tables: Iterable[Table], _output: str | None) -> None:
     file = click.open_file(_output, "w") if _output else click.open_file("-", "w")
 
     tables_info = {}
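The `Iterator` → `Iterable` loosening matches what `generate_dbml_file` actually does: a single `for` pass over the tables. A minimal standalone sketch of why the broader annotation is friendlier to callers (illustrative names, not warehouse code):

```
from collections.abc import Iterable


def render(names: Iterable[str]) -> list[str]:
    # The function only loops over `names` once, so Iterable is the honest contract.
    return [name.upper() for name in names]


render(["alpha", "beta"])        # a list is Iterable but not an Iterator
render(iter(["alpha", "beta"]))  # an Iterator is also Iterable, so this still type-checks
```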

warehouse/packaging/models.py (+27 -15)

@@ -333,33 +333,45 @@ def __acl__(self):
             acls.append((Allow, f"oidc:{publisher.id}", [Permissions.ProjectsUpload]))
 
         # Get all of the users for this project.
-        query = session.query(Role).filter(Role.project == self)
-        query = query.options(orm.lazyload(Role.project))
-        query = query.options(orm.lazyload(Role.user))
+        user_query = (
+            session.query(Role)
+            .filter(Role.project == self)
+            .options(orm.lazyload(Role.project), orm.lazyload(Role.user))
+        )
         permissions = {
             (role.user_id, "Administer" if role.role_name == "Owner" else "Upload")
-            for role in query.all()
+            for role in user_query.all()
         }
 
         # Add all of the team members for this project.
-        query = session.query(TeamProjectRole).filter(TeamProjectRole.project == self)
-        query = query.options(orm.lazyload(TeamProjectRole.project))
-        query = query.options(orm.lazyload(TeamProjectRole.team))
-        for role in query.all():
+        team_query = (
+            session.query(TeamProjectRole)
+            .filter(TeamProjectRole.project == self)
+            .options(
+                orm.lazyload(TeamProjectRole.project),
+                orm.lazyload(TeamProjectRole.team),
+            )
+        )
+        for role in team_query.all():
             permissions |= {
                 (user.id, "Administer" if role.role_name.value == "Owner" else "Upload")
                 for user in role.team.members
             }
 
         # Add all organization owners for this project.
         if self.organization:
-            query = session.query(OrganizationRole).filter(
-                OrganizationRole.organization == self.organization,
-                OrganizationRole.role_name == OrganizationRoleType.Owner,
+            org_query = (
+                session.query(OrganizationRole)
+                .filter(
+                    OrganizationRole.organization == self.organization,
+                    OrganizationRole.role_name == OrganizationRoleType.Owner,
+                )
+                .options(
+                    orm.lazyload(OrganizationRole.organization),
+                    orm.lazyload(OrganizationRole.user),
+                )
             )
-            query = query.options(orm.lazyload(OrganizationRole.organization))
-            query = query.options(orm.lazyload(OrganizationRole.user))
-            permissions |= {(role.user_id, "Administer") for role in query.all()}
+            permissions |= {(role.user_id, "Administer") for role in org_query.all()}
 
         for user_id, permission_name in sorted(permissions, key=lambda x: (x[1], x[0])):
             # Disallow Write permissions for Projects in quarantine, allow Upload
@@ -759,7 +771,7 @@ def urls_by_verify_status(self, *, verified: bool):
         return _urls
 
     def verified_user_name_and_repo_name(
-        self, domains: set[str], reserved_names: typing.Sequence[str] | None = None
+        self, domains: set[str], reserved_names: typing.Collection[str] | None = None
    ):
        for _, url in self.urls_by_verify_status(verified=True).items():
            try:
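The rewritten `__acl__` queries above replace repeated `query = query.options(...)` reassignment with one chained expression per query, each under its own name so type checkers see a single consistent type. A self-contained sketch of that chained `Query`/`lazyload` style, using stand-in models rather than the warehouse schema:

```
from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    lazyload,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Project(Base):
    __tablename__ = "projects"
    id: Mapped[int] = mapped_column(primary_key=True)


class Role(Base):
    __tablename__ = "roles"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int]
    project_id: Mapped[int] = mapped_column(ForeignKey("projects.id"))
    project: Mapped[Project] = relationship()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    project = Project(id=1)
    session.add_all([project, Role(id=1, user_id=42, project_id=1)])
    session.flush()

    # One chained expression instead of repeated reassignment; the lazyload
    # option keeps the related Project from being eagerly loaded.
    role_query = (
        session.query(Role)
        .filter(Role.project == project)
        .options(lazyload(Role.project))
    )
    owner_ids = {role.user_id for role in role_query.all()}
    print(owner_ids)  # {42}
```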

warehouse/search/tasks.py (+1 -1)

@@ -42,7 +42,7 @@
 def _project_docs(db, project_name=None):
     releases_list = (
         select(Release.id)
-        .filter(Release.yanked.is_(False), Release.files)
+        .filter(Release.yanked.is_(False), Release.files.any())
         .order_by(
             Release.project_id,
             Release.is_prerelease.nullslast(),
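Filtering on a bare relationship attribute (`Release.files`) leans on SQLAlchemy coercing it into an EXISTS clause; `Release.files.any()` spells that subquery out explicitly. A standalone sketch with stand-in models (not the warehouse schema) that prints the rendered SQL:

```
from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship


class Base(DeclarativeBase):
    pass


class Release(Base):
    __tablename__ = "releases"
    id: Mapped[int] = mapped_column(primary_key=True)
    yanked: Mapped[bool] = mapped_column(default=False)
    files: Mapped[list["File"]] = relationship(back_populates="release")


class File(Base):
    __tablename__ = "release_files"
    id: Mapped[int] = mapped_column(primary_key=True)
    release_id: Mapped[int] = mapped_column(ForeignKey("releases.id"))
    release: Mapped[Release] = relationship(back_populates="files")


# `.any()` renders an explicit EXISTS subquery selecting releases that have
# at least one related file row.
stmt = select(Release.id).filter(Release.yanked.is_(False), Release.files.any())
print(stmt)  # ... WHERE releases.yanked IS ... AND (EXISTS (SELECT 1 FROM release_files ...))
```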

warehouse/sitemap/views.py (+4 -4)

@@ -88,15 +88,15 @@ def sitemap_index(request):
         .group_by(User.sitemap_bucket)
         .all()
     )
-    buckets = {}
+    buckets: dict[str, datetime.datetime] = {}
     for b in itertools.chain(projects, users):
         current = buckets.setdefault(b.sitemap_bucket, b.modified)
         if current is None or (b.modified is not None and b.modified > current):
             buckets[b.sitemap_bucket] = b.modified
-    buckets = [Bucket(name=k, modified=v) for k, v in buckets.items()]
-    buckets.sort(key=lambda x: x.name)
+    bucket_list = [Bucket(name=k, modified=v) for k, v in buckets.items()]
+    bucket_list.sort(key=lambda x: x.name)
 
-    return {"buckets": buckets}
+    return {"buckets": bucket_list}
 
 
 @view_config(
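The sitemap change avoids re-binding one `buckets` name from `dict` to `list`, which static type checkers reject; the accumulator gets an explicit annotation and the derived list gets its own name. A tiny standalone illustration of the same idea:

```
import datetime

# Annotate the accumulator once, then give the derived list a separate name
# instead of re-binding `buckets` to a different type.
buckets: dict[str, datetime.datetime] = {
    "bucket-a": datetime.datetime(2024, 1, 1),
    "bucket-b": datetime.datetime(2024, 6, 1),
}
bucket_list = sorted(buckets.items())  # list[tuple[str, datetime.datetime]]
print(bucket_list)
```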

warehouse/tasks.py (+90 -6)

@@ -10,11 +10,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import functools
 import hashlib
 import logging
 import os
 import time
+import typing
 import urllib.parse
 
 import celery
@@ -32,6 +35,9 @@
 from warehouse.config import Environment
 from warehouse.metrics import IMetricsService
 
+if typing.TYPE_CHECKING:
+    from pyramid.request import Request
+
 # We need to trick Celery into supporting rediss:// URLs which is how redis-py
 # signals that you should use Redis with TLS.
 celery.app.backends.BACKEND_ALIASES["rediss"] = (
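The new `typing.TYPE_CHECKING` guard together with `from __future__ import annotations` lets `Request` appear in annotations without importing Pyramid's request module at runtime. A minimal standalone sketch of the pattern (the `describe` function below is illustrative, not part of the diff):

```
from __future__ import annotations

import typing

if typing.TYPE_CHECKING:
    # Only evaluated by static type checkers; no runtime import cost or cycle.
    from pyramid.request import Request


def describe(request: Request) -> str:
    # With postponed evaluation, the annotation stays a string at runtime.
    return f"{request.method} {request.path}"
```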
@@ -53,7 +59,19 @@ def _params_from_url(self, url, defaults):
 
 
 class WarehouseTask(celery.Task):
-    def __new__(cls, *args, **kwargs):
+    """
+    A custom Celery Task that integrates with Pyramid's transaction manager and
+    metrics service.
+    """
+
+    __header__: typing.Callable
+    _wh_original_run: typing.Callable
+
+    def __new__(cls, *args, **kwargs) -> WarehouseTask:
+        """
+        Override to wrap the `run` method of the task with a new method that
+        will handle exceptions from the task and retry them if they're retryable.
+        """
         obj = super().__new__(cls, *args, **kwargs)
         if getattr(obj, "__header__", None) is not None:
             obj.__header__ = functools.partial(obj.__header__, object())
@@ -82,16 +100,34 @@ def run(*args, **kwargs):
                     metrics.increment("warehouse.task.failed", tags=metric_tags)
                     raise
 
-        obj._wh_original_run, obj.run = obj.run, run
+        # Reassign the `run` method to the new one we've created.
+        obj._wh_original_run, obj.run = obj.run, run  # type: ignore[method-assign]
 
         return obj
 
     def __call__(self, *args, **kwargs):
+        """
+        Override to inject a faux request object into the task when it's called.
+        There's no WSGI request object available when a task is called, so we
+        create a fake one here. This is necessary as a lot of our code assumes
+        that there's a Pyramid request object available.
+        """
         return super().__call__(*(self.get_request(),) + args, **kwargs)
 
-    def get_request(self):
+    def get_request(self) -> Request:
+        """
+        Get a request object to use for this task.
+
+        This will either return the request object that was injected into the
+        task when it was called, or it will create a new request object to use
+        for the task.
+
+        Note: The `type: ignore` comments are necessary because the `pyramid_env`
+        attribute is not defined on the request object, but we're adding it
+        dynamically.
+        """
         if not hasattr(self.request, "pyramid_env"):
-            registry = self.app.pyramid_config.registry
+            registry = self.app.pyramid_config.registry  # type: ignore[attr-defined]
             env = pyramid.scripting.prepare(registry=registry)
             env["request"].tm = transaction.TransactionManager(explicit=True)
             env["request"].timings = {"new_request_start": time.time() * 1000}
@@ -101,15 +137,29 @@ def get_request(self):
             ).hexdigest()
             self.request.update(pyramid_env=env)
 
-        return self.request.pyramid_env["request"]
+        return self.request.pyramid_env["request"]  # type: ignore[attr-defined]
 
     def after_return(self, status, retval, task_id, args, kwargs, einfo):
+        """
+        Called after the task has returned. This is where we'll clean up the
+        request object that we injected into the task.
+        """
         if hasattr(self.request, "pyramid_env"):
             pyramid_env = self.request.pyramid_env
             pyramid_env["request"]._process_finished_callbacks()
             pyramid_env["closer"]()
 
     def apply_async(self, *args, **kwargs):
+        """
+        Override the apply_async method to add an after commit hook to the
+        transaction manager to send the task after the transaction has been
+        committed.
+
+        This is necessary because we want to ensure that the task is only sent
+        after the transaction has been committed. This is important because we
+        want to ensure that the task is only sent if the transaction was
+        successful.
+        """
         # The API design of Celery makes this threadlocal pretty impossible to
         # avoid :(
         request = get_current_request()
@@ -137,17 +187,51 @@ def apply_async(self, *args, **kwargs):
         )
 
     def retry(self, *args, **kwargs):
+        """
+        Override the retry method to increment a metric when a task is retried.
+
+        This is necessary because the `retry` method is called when a task is
+        retried, and we want to track how many times a task has been retried.
+        """
         request = get_current_request()
         metrics = request.find_service(IMetricsService, context=None)
         metrics.increment("warehouse.task.retried", tags=[f"task:{self.name}"])
         return super().retry(*args, **kwargs)
 
     def _after_commit_hook(self, success, *args, **kwargs):
+        """
+        This is the hook that gets called after the transaction has been
+        committed. We'll only send the task if the transaction was successful.
+        """
         if success:
             super().apply_async(*args, **kwargs)
 
 
 def task(**kwargs):
+    """
+    A decorator that can be used to define a Celery task.
+
+    A thin wrapper around Celery's `task` decorator that allows us to attach
+    the task to the Celery app when the configuration is scanned during the
+    application startup.
+
+    This decorator also sets the `shared` option to `False` by default. This
+    means that the task will be created anew for each worker process that is
+    started. This is important because the `WarehouseTask` class that we use
+    for our tasks is not thread-safe, so we need to ensure that each worker
+    process has its own instance of the task.
+
+    This decorator also adds the task to the `warehouse` category in the
+    configuration scanner. This is important because we use this category to
+    find all the tasks that have been defined in the configuration.
+
+    Example usage:
+    ```
+    @tasks.task(...)
+    def my_task(self, *args, **kwargs):
+        pass
+    ```
+    """
     kwargs.setdefault("shared", False)
 
     def deco(wrapped):
@@ -193,7 +277,7 @@ def add_task():
 def includeme(config):
     s = config.registry.settings
 
-    broker_transport_options = {}
+    broker_transport_options: dict[str, str | dict] = {}
 
     broker_url = s.get("celery.broker_url")
     if broker_url is None:
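The `apply_async` docstring above describes deferring dispatch until the Pyramid transaction commits; the registered `_after_commit_hook` only forwards the task when the commit succeeded. A standalone sketch of that hook mechanism using the `transaction` package directly (the `_enqueue` function is a stand-in for the real Celery send, not warehouse code):

```
import transaction


def _enqueue(success: bool, task_name: str) -> None:
    # After-commit hooks receive the commit status as their first argument.
    if success:
        print(f"enqueue {task_name!r}: transaction committed")
    else:
        print(f"drop {task_name!r}: transaction failed")


tm = transaction.TransactionManager(explicit=True)
with tm:
    # Register the hook on the current transaction; it fires only after the
    # commit attempt, so work from rolled-back requests never enqueues a task.
    tm.get().addAfterCommitHook(_enqueue, args=("send_email",))
```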

0 commit comments
