Skip to content

Commit e613301

Browse files
committed
Annotate test helpers
1 parent ee391ca commit e613301

File tree

3 files changed

+41
-22
lines changed

3 files changed

+41
-22
lines changed

tests/fields.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
from django.db import models
6+
from django.db.backends.base.base import BaseDatabaseWrapper
47

58

6-
def mutable_from_db(value):
9+
def mutable_from_db(value: object) -> Any:
710
if value == '':
811
return None
912
try:
@@ -14,7 +17,7 @@ def mutable_from_db(value):
1417
return value
1518

1619

17-
def mutable_to_db(value):
20+
def mutable_to_db(value: object) -> str:
1821
if value is None:
1922
return ''
2023
if isinstance(value, list):
@@ -23,12 +26,12 @@ def mutable_to_db(value):
2326

2427

2528
class MutableField(models.TextField):
26-
def to_python(self, value):
29+
def to_python(self, value: object) -> Any:
2730
return mutable_from_db(value)
2831

29-
def from_db_value(self, value, expression, connection):
32+
def from_db_value(self, value: object, expression: object, connection: BaseDatabaseWrapper) -> Any:
3033
return mutable_from_db(value)
3134

32-
def get_db_prep_save(self, value, connection):
35+
def get_db_prep_save(self, value: object, connection: BaseDatabaseWrapper) -> str:
3336
value = super().get_db_prep_save(value, connection)
3437
return mutable_to_db(value)

tests/managers.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from __future__ import annotations
22

3-
from model_utils.managers import SoftDeletableManager, SoftDeletableQuerySet
3+
from model_utils.managers import (
4+
ModelT,
5+
QuerySet,
6+
SoftDeletableManager,
7+
SoftDeletableQuerySet,
8+
)
49

510

6-
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet):
7-
def only_read(self):
11+
class CustomSoftDeleteQuerySet(SoftDeletableQuerySet[ModelT]):
12+
def only_read(self) -> QuerySet[ModelT]:
813
return self.filter(is_read=True)
914

1015

11-
class CustomSoftDeleteManager(SoftDeletableManager):
16+
class CustomSoftDeleteManager(SoftDeletableManager[ModelT]):
1217
_queryset_class = CustomSoftDeleteQuerySet
1318

14-
def only_read(self):
15-
return self.get_queryset().only_read()
19+
def only_read(self) -> QuerySet[ModelT]:
20+
qs = self.get_queryset()
21+
assert isinstance(qs, self._queryset_class), qs
22+
return qs.only_read()

tests/models.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import ClassVar, TypeVar
3+
from typing import Any, ClassVar, TypeVar, overload
44

55
from django.db import models
66
from django.db.models import Manager
@@ -40,7 +40,7 @@ class InheritanceManagerTestParent(models.Model):
4040
on_delete=models.CASCADE)
4141
objects: ClassVar[InheritanceManager[InheritanceManagerTestParent]] = InheritanceManager()
4242

43-
def __str__(self):
43+
def __str__(self) -> str:
4444
return "{}({})".format(
4545
self.__class__.__name__[len('InheritanceManagerTest'):],
4646
self.pk,
@@ -214,7 +214,7 @@ class Tracked(models.Model):
214214

215215
tracker = FieldTracker()
216216

217-
def save(self, *args, **kwargs):
217+
def save(self, *args: Any, **kwargs: Any) -> None:
218218
""" No-op save() to ensure that FieldTracker.patch_save() works. """
219219
super().save(*args, **kwargs)
220220

@@ -226,7 +226,7 @@ class TrackerTimeStamped(TimeStampedModel):
226226

227227
tracker = FieldTracker()
228228

229-
def save(self, *args, **kwargs):
229+
def save(self, *args: Any, **kwargs: Any) -> None:
230230
""" Automatically add "modified" to update_fields."""
231231
update_fields = kwargs.get('update_fields')
232232
if update_fields is not None:
@@ -244,7 +244,7 @@ class TrackedFK(models.Model):
244244

245245
class TrackedAbstract(AbstractTracked):
246246
name = models.CharField(max_length=20)
247-
number = models.IntegerField() # type: ignore[assignment]
247+
number = models.IntegerField()
248248
mutable = MutableField(default=None)
249249

250250
tracker = FieldTracker()
@@ -261,7 +261,7 @@ class TrackedNonFieldAttr(models.Model):
261261
number = models.FloatField()
262262

263263
@property
264-
def rounded(self):
264+
def rounded(self) -> int | None:
265265
return round(self.number) if self.number is not None else None
266266

267267
tracker = FieldTracker(fields=['rounded'])
@@ -360,28 +360,37 @@ class StringyDescriptor:
360360
"""
361361
Descriptor that returns a string version of the underlying integer value.
362362
"""
363-
def __init__(self, name):
363+
def __init__(self, name: str):
364364
self.name = name
365365

366-
def __get__(self, obj, cls=None):
366+
@overload
367+
def __get__(self, obj: None, cls: type[models.Model] | None = None) -> StringyDescriptor:
368+
...
369+
370+
@overload
371+
def __get__(self, obj: models.Model, cls: type[models.Model]) -> str:
372+
...
373+
374+
def __get__(self, obj: models.Model | None, cls: type[models.Model] | None = None) -> StringyDescriptor | str:
367375
if obj is None:
368376
return self
369377
if self.name in obj.get_deferred_fields():
370378
# This queries the database, and sets the value on the instance.
379+
assert cls is not None
371380
fields_map = {f.name: f for f in cls._meta.fields}
372381
field = fields_map[self.name]
373382
DeferredAttribute(field=field).__get__(obj, cls)
374383
return str(obj.__dict__[self.name])
375384

376-
def __set__(self, obj, value):
385+
def __set__(self, obj: object, value: str) -> None:
377386
obj.__dict__[self.name] = int(value)
378387

379-
def __delete__(self, obj):
388+
def __delete__(self, obj: object) -> None:
380389
del obj.__dict__[self.name]
381390

382391

383392
class CustomDescriptorField(models.IntegerField):
384-
def contribute_to_class(self, cls, name, *args, **kwargs):
393+
def contribute_to_class(self, cls: type[models.Model], name: str, *args: Any, **kwargs: Any) -> None:
385394
super().contribute_to_class(cls, name, *args, **kwargs)
386395
setattr(cls, name, StringyDescriptor(name))
387396

0 commit comments

Comments
 (0)