Skip to content

Commit 87dfa2e

Browse files
committed
Annotate test_field_tracker module
1 parent d617b26 commit 87dfa2e

File tree

3 files changed

+65
-35
lines changed

3 files changed

+65
-35
lines changed

tests/test_fields/test_field_tracker.py

+61-31
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from typing import TYPE_CHECKING, Any
34
from unittest import skip
45

56
from django.core.cache import cache
@@ -9,7 +10,7 @@
910
from django.test import TestCase
1011

1112
from model_utils import FieldTracker
12-
from model_utils.tracker import DescriptorWrapper
13+
from model_utils.tracker import DescriptorWrapper, FieldInstanceTracker
1314
from tests.models import (
1415
InheritedModelTracked,
1516
InheritedTracked,
@@ -28,12 +29,18 @@
2829
TrackerTimeStamped,
2930
)
3031

32+
if TYPE_CHECKING:
33+
MixinBase = TestCase
34+
else:
35+
MixinBase = object
3136

32-
class FieldTrackerTestCase(TestCase):
3337

34-
tracker = None
38+
class FieldTrackerMixin(MixinBase):
3539

36-
def assertHasChanged(self, *, tracker=None, **kwargs):
40+
tracker: FieldInstanceTracker
41+
instance: models.Model
42+
43+
def assertHasChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
3744
if tracker is None:
3845
tracker = self.tracker
3946
for field, value in kwargs.items():
@@ -43,29 +50,35 @@ def assertHasChanged(self, *, tracker=None, **kwargs):
4350
else:
4451
self.assertEqual(tracker.has_changed(field), value)
4552

46-
def assertPrevious(self, *, tracker=None, **kwargs):
53+
def assertPrevious(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
4754
if tracker is None:
4855
tracker = self.tracker
4956
for field, value in kwargs.items():
5057
self.assertEqual(tracker.previous(field), value)
5158

52-
def assertChanged(self, *, tracker=None, **kwargs):
59+
def assertChanged(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
5360
if tracker is None:
5461
tracker = self.tracker
5562
self.assertEqual(tracker.changed(), kwargs)
5663

57-
def assertCurrent(self, *, tracker=None, **kwargs):
64+
def assertCurrent(self, *, tracker: FieldInstanceTracker | None = None, **kwargs: Any) -> None:
5865
if tracker is None:
5966
tracker = self.tracker
6067
self.assertEqual(tracker.current(), kwargs)
6168

62-
def update_instance(self, **kwargs):
69+
def update_instance(self, **kwargs: Any) -> None:
6370
for field, value in kwargs.items():
6471
setattr(self.instance, field, value)
6572
self.instance.save()
6673

6774

68-
class FieldTrackerCommonTests:
75+
class FieldTrackerCommonMixin(FieldTrackerMixin):
76+
77+
instance: (
78+
Tracked | TrackedNotDefault | TrackedMultiple
79+
| ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple
80+
| TrackedAbstract
81+
)
6982

7083
def test_pre_save_previous(self) -> None:
7184
self.assertPrevious(name=None, number=None)
@@ -74,9 +87,10 @@ def test_pre_save_previous(self) -> None:
7487
self.assertPrevious(name=None, number=None)
7588

7689

77-
class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests):
90+
class FieldTrackerTests(FieldTrackerCommonMixin, TestCase):
7891

79-
tracked_class: type[models.Model] = Tracked
92+
tracked_class: type[Tracked | ModelTracked | TrackedAbstract] = Tracked
93+
instance: Tracked | ModelTracked | TrackedAbstract
8094

8195
def setUp(self) -> None:
8296
self.instance = self.tracked_class()
@@ -221,6 +235,7 @@ def test_with_deferred(self) -> None:
221235
self.instance.number = 1
222236
self.instance.save()
223237
item = self.tracked_class.objects.only('name').first()
238+
assert item is not None
224239
self.assertTrue(item.get_deferred_fields())
225240

226241
# has_changed() returns False for deferred fields, without un-deferring them.
@@ -236,6 +251,7 @@ def test_with_deferred(self) -> None:
236251

237252
# examining a deferred field un-defers it
238253
item = self.tracked_class.objects.only('name').first()
254+
assert item is not None
239255
self.assertEqual(item.number, 1)
240256
self.assertTrue('number' not in item.get_deferred_fields())
241257
self.assertEqual(item.tracker.previous('number'), 1)
@@ -254,6 +270,7 @@ def test_with_deferred(self) -> None:
254270
if self.tracked_class == Tracked:
255271

256272
item = self.tracked_class.objects.only('name').first()
273+
assert item is not None
257274
item.number = 2
258275

259276
# previous() fetches correct value from database after deferred field is assigned
@@ -280,10 +297,10 @@ def test_with_deferred_fields_access_multiple(self) -> None:
280297
instance.name
281298

282299

283-
class FieldTrackedModelCustomTests(FieldTrackerTestCase,
284-
FieldTrackerCommonTests):
300+
class FieldTrackedModelCustomTests(FieldTrackerCommonMixin, TestCase):
285301

286-
tracked_class: type[models.Model] = TrackedNotDefault
302+
tracked_class: type[TrackedNotDefault | ModelTrackedNotDefault] = TrackedNotDefault
303+
instance: TrackedNotDefault | ModelTrackedNotDefault
287304

288305
def setUp(self) -> None:
289306
self.instance = self.tracked_class()
@@ -360,9 +377,10 @@ def test_update_fields(self) -> None:
360377
self.assertChanged()
361378

362379

363-
class FieldTrackedModelAttributeTests(FieldTrackerTestCase):
380+
class FieldTrackedModelAttributeTests(FieldTrackerMixin, TestCase):
364381

365382
tracked_class = TrackedNonFieldAttr
383+
instance: TrackedNonFieldAttr
366384

367385
def setUp(self) -> None:
368386
self.instance = self.tracked_class()
@@ -411,10 +429,10 @@ def test_current(self) -> None:
411429
self.assertCurrent(rounded=8)
412430

413431

414-
class FieldTrackedModelMultiTests(FieldTrackerTestCase,
415-
FieldTrackerCommonTests):
432+
class FieldTrackedModelMultiTests(FieldTrackerCommonMixin, TestCase):
416433

417-
tracked_class: type[models.Model] = TrackedMultiple
434+
tracked_class: type[TrackedMultiple | ModelTrackedMultiple] = TrackedMultiple
435+
instance: TrackedMultiple | ModelTrackedMultiple
418436

419437
def setUp(self) -> None:
420438
self.instance = self.tracked_class()
@@ -503,10 +521,11 @@ def test_current(self) -> None:
503521
self.assertCurrent(tracker=self.trackers[1], number=8)
504522

505523

506-
class FieldTrackerForeignKeyTests(FieldTrackerTestCase):
524+
class FieldTrackerForeignKeyMixin(FieldTrackerMixin):
507525

508-
fk_class: type[models.Model] = Tracked
509-
tracked_class: type[models.Model] = TrackedFK
526+
fk_class: type[Tracked | ModelTracked]
527+
tracked_class: type[TrackedFK | ModelTrackedFK]
528+
instance: TrackedFK | ModelTrackedFK
510529

511530
def setUp(self) -> None:
512531
self.old_fk = self.fk_class.objects.create(number=8)
@@ -545,11 +564,18 @@ def test_custom_without_id(self) -> None:
545564
self.assertCurrent(fk=self.instance.fk_id)
546565

547566

548-
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerTestCase):
567+
class FieldTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
568+
569+
fk_class = Tracked
570+
tracked_class = TrackedFK
571+
572+
573+
class FieldTrackerForeignKeyPrefetchRelatedTests(FieldTrackerMixin, TestCase):
549574
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError."""
550575

551576
fk_class = Tracked
552577
tracked_class = TrackedFK
578+
instance: TrackedFK
553579

554580
def setUp(self) -> None:
555581
model_tracked = self.fk_class.objects.create(name="", number=0)
@@ -568,10 +594,11 @@ def test_custom_without_id(self) -> None:
568594
self.assertIsNotNone(list(self.tracked_class.objects.prefetch_related("fk")))
569595

570596

571-
class FieldTrackerTimeStampedTests(FieldTrackerTestCase):
597+
class FieldTrackerTimeStampedTests(FieldTrackerMixin, TestCase):
572598

573599
fk_class = Tracked
574600
tracked_class = TrackerTimeStamped
601+
instance: TrackerTimeStamped
575602

576603
def setUp(self) -> None:
577604
self.instance = self.tracked_class.objects.create(name='old', number=1)
@@ -607,9 +634,10 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests):
607634
tracked_class = InheritedTrackedFK
608635

609636

610-
class FieldTrackerFileFieldTests(FieldTrackerTestCase):
637+
class FieldTrackerFileFieldTests(FieldTrackerMixin, TestCase):
611638

612639
tracked_class = TrackedFileField
640+
instance: TrackedFileField
613641

614642
def setUp(self) -> None:
615643
self.instance = self.tracked_class()
@@ -631,7 +659,7 @@ def test_saved_data_without_instance(self) -> None:
631659
self.assertEqual(self.tracker.saved_data, {})
632660
self.update_instance(some_file=self.some_file)
633661
field_file_copy = self.tracker.saved_data.get('some_file')
634-
self.assertIsNotNone(field_file_copy)
662+
assert field_file_copy is not None
635663
self.assertEqual(field_file_copy.__getstate__().get('instance'), None)
636664
self.assertEqual(self.instance.some_file.instance, self.instance)
637665
self.assertIsInstance(self.instance.some_file, FieldFile)
@@ -732,7 +760,8 @@ def test_current(self) -> None:
732760

733761
class ModelTrackerTests(FieldTrackerTests):
734762

735-
tracked_class: type[models.Model] = ModelTracked
763+
tracked_class: type[ModelTracked | TrackedAbstract] = ModelTracked
764+
instance: ModelTracked
736765

737766
def test_cache_compatible(self) -> None:
738767
cache.set('key', self.instance)
@@ -848,10 +877,11 @@ def test_pre_save_changed(self) -> None:
848877
self.assertChanged()
849878

850879

851-
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyTests):
880+
class ModelTrackerForeignKeyTests(FieldTrackerForeignKeyMixin, TestCase):
852881

853882
fk_class = ModelTracked
854883
tracked_class = ModelTrackedFK
884+
instance: ModelTrackedFK
855885

856886
def test_custom_without_id(self) -> None:
857887
with self.assertNumQueries(2):
@@ -889,11 +919,11 @@ def setUp(self) -> None:
889919
self.instance = Tracked.objects.create(number=1)
890920
self.tracker = self.instance.tracker
891921

892-
def assertChanged(self, *fields):
922+
def assertChanged(self, *fields: str) -> None:
893923
for f in fields:
894924
self.assertTrue(self.tracker.has_changed(f))
895925

896-
def assertNotChanged(self, *fields):
926+
def assertNotChanged(self, *fields: str) -> None:
897927
for f in fields:
898928
self.assertFalse(self.tracker.has_changed(f))
899929

@@ -924,7 +954,7 @@ def test_context_manager_fields(self) -> None:
924954
def test_tracker_decorator(self) -> None:
925955

926956
@Tracked.tracker
927-
def tracked_method(obj):
957+
def tracked_method(obj: Tracked) -> None:
928958
obj.name = 'new'
929959
self.assertChanged('name')
930960

@@ -935,7 +965,7 @@ def tracked_method(obj):
935965
def test_tracker_decorator_fields(self) -> None:
936966

937967
@Tracked.tracker(fields=['name'])
938-
def tracked_method(obj):
968+
def tracked_method(obj: Tracked) -> None:
939969
obj.name = 'new'
940970
obj.number += 1
941971
self.assertChanged('name', 'number')

tests/test_fields/test_monitor_field.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def test_double_save(self) -> None:
3434

3535
def test_no_monitor_arg(self) -> None:
3636
with self.assertRaises(TypeError):
37-
MonitorField()
37+
MonitorField() # type: ignore[call-arg]
3838

3939
def test_nullable_without_default_deprecation(self) -> None:
4040
warning_message = (

tests/test_fields/test_urlsafe_token_field.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_factory_default(self) -> None:
3131

3232
def test_factory_not_callable(self) -> None:
3333
with self.assertRaises(TypeError):
34-
UrlsafeTokenField(factory='INVALID')
34+
UrlsafeTokenField(factory='INVALID') # type: ignore[arg-type]
3535

3636
def test_get_default(self) -> None:
3737
field = UrlsafeTokenField()
@@ -57,8 +57,8 @@ def test_no_default_param(self) -> None:
5757
self.assertIs(field.default, NOT_PROVIDED)
5858

5959
def test_deconstruct(self) -> None:
60-
def test_factory() -> None:
61-
pass
60+
def test_factory(max_length: int) -> str:
61+
assert False
6262
instance = UrlsafeTokenField(factory=test_factory)
6363
name, path, args, kwargs = instance.deconstruct()
6464
new_instance = UrlsafeTokenField(*args, **kwargs)

0 commit comments

Comments
 (0)