1
1
from __future__ import annotations
2
2
3
+ from typing import TYPE_CHECKING , Any
3
4
from unittest import skip
4
5
5
6
from django .core .cache import cache
9
10
from django .test import TestCase
10
11
11
12
from model_utils import FieldTracker
12
- from model_utils .tracker import DescriptorWrapper
13
+ from model_utils .tracker import DescriptorWrapper , FieldInstanceTracker
13
14
from tests .models import (
14
15
InheritedModelTracked ,
15
16
InheritedTracked ,
28
29
TrackerTimeStamped ,
29
30
)
30
31
32
+ if TYPE_CHECKING :
33
+ MixinBase = TestCase
34
+ else :
35
+ MixinBase = object
31
36
32
- class FieldTrackerTestCase (TestCase ):
33
37
34
- tracker = None
38
+ class FieldTrackerMixin ( MixinBase ):
35
39
36
- def assertHasChanged (self , * , tracker = None , ** kwargs ):
40
+ tracker : FieldInstanceTracker
41
+ instance : models .Model
42
+
43
+ def assertHasChanged (self , * , tracker : FieldInstanceTracker | None = None , ** kwargs : Any ) -> None :
37
44
if tracker is None :
38
45
tracker = self .tracker
39
46
for field , value in kwargs .items ():
@@ -43,29 +50,35 @@ def assertHasChanged(self, *, tracker=None, **kwargs):
43
50
else :
44
51
self .assertEqual (tracker .has_changed (field ), value )
45
52
46
- def assertPrevious (self , * , tracker = None , ** kwargs ) :
53
+ def assertPrevious (self , * , tracker : FieldInstanceTracker | None = None , ** kwargs : Any ) -> None :
47
54
if tracker is None :
48
55
tracker = self .tracker
49
56
for field , value in kwargs .items ():
50
57
self .assertEqual (tracker .previous (field ), value )
51
58
52
- def assertChanged (self , * , tracker = None , ** kwargs ) :
59
+ def assertChanged (self , * , tracker : FieldInstanceTracker | None = None , ** kwargs : Any ) -> None :
53
60
if tracker is None :
54
61
tracker = self .tracker
55
62
self .assertEqual (tracker .changed (), kwargs )
56
63
57
- def assertCurrent (self , * , tracker = None , ** kwargs ) :
64
+ def assertCurrent (self , * , tracker : FieldInstanceTracker | None = None , ** kwargs : Any ) -> None :
58
65
if tracker is None :
59
66
tracker = self .tracker
60
67
self .assertEqual (tracker .current (), kwargs )
61
68
62
- def update_instance (self , ** kwargs ) :
69
+ def update_instance (self , ** kwargs : Any ) -> None :
63
70
for field , value in kwargs .items ():
64
71
setattr (self .instance , field , value )
65
72
self .instance .save ()
66
73
67
74
68
- class FieldTrackerCommonTests :
75
+ class FieldTrackerCommonMixin (FieldTrackerMixin ):
76
+
77
+ instance : (
78
+ Tracked | TrackedNotDefault | TrackedMultiple
79
+ | ModelTracked | ModelTrackedNotDefault | ModelTrackedMultiple
80
+ | TrackedAbstract
81
+ )
69
82
70
83
def test_pre_save_previous (self ) -> None :
71
84
self .assertPrevious (name = None , number = None )
@@ -74,9 +87,10 @@ def test_pre_save_previous(self) -> None:
74
87
self .assertPrevious (name = None , number = None )
75
88
76
89
77
- class FieldTrackerTests (FieldTrackerTestCase , FieldTrackerCommonTests ):
90
+ class FieldTrackerTests (FieldTrackerCommonMixin , TestCase ):
78
91
79
- tracked_class : type [models .Model ] = Tracked
92
+ tracked_class : type [Tracked | ModelTracked | TrackedAbstract ] = Tracked
93
+ instance : Tracked | ModelTracked | TrackedAbstract
80
94
81
95
def setUp (self ) -> None :
82
96
self .instance = self .tracked_class ()
@@ -221,6 +235,7 @@ def test_with_deferred(self) -> None:
221
235
self .instance .number = 1
222
236
self .instance .save ()
223
237
item = self .tracked_class .objects .only ('name' ).first ()
238
+ assert item is not None
224
239
self .assertTrue (item .get_deferred_fields ())
225
240
226
241
# has_changed() returns False for deferred fields, without un-deferring them.
@@ -236,6 +251,7 @@ def test_with_deferred(self) -> None:
236
251
237
252
# examining a deferred field un-defers it
238
253
item = self .tracked_class .objects .only ('name' ).first ()
254
+ assert item is not None
239
255
self .assertEqual (item .number , 1 )
240
256
self .assertTrue ('number' not in item .get_deferred_fields ())
241
257
self .assertEqual (item .tracker .previous ('number' ), 1 )
@@ -254,6 +270,7 @@ def test_with_deferred(self) -> None:
254
270
if self .tracked_class == Tracked :
255
271
256
272
item = self .tracked_class .objects .only ('name' ).first ()
273
+ assert item is not None
257
274
item .number = 2
258
275
259
276
# previous() fetches correct value from database after deferred field is assigned
@@ -280,10 +297,10 @@ def test_with_deferred_fields_access_multiple(self) -> None:
280
297
instance .name
281
298
282
299
283
- class FieldTrackedModelCustomTests (FieldTrackerTestCase ,
284
- FieldTrackerCommonTests ):
300
+ class FieldTrackedModelCustomTests (FieldTrackerCommonMixin , TestCase ):
285
301
286
- tracked_class : type [models .Model ] = TrackedNotDefault
302
+ tracked_class : type [TrackedNotDefault | ModelTrackedNotDefault ] = TrackedNotDefault
303
+ instance : TrackedNotDefault | ModelTrackedNotDefault
287
304
288
305
def setUp (self ) -> None :
289
306
self .instance = self .tracked_class ()
@@ -360,9 +377,10 @@ def test_update_fields(self) -> None:
360
377
self .assertChanged ()
361
378
362
379
363
- class FieldTrackedModelAttributeTests (FieldTrackerTestCase ):
380
+ class FieldTrackedModelAttributeTests (FieldTrackerMixin , TestCase ):
364
381
365
382
tracked_class = TrackedNonFieldAttr
383
+ instance : TrackedNonFieldAttr
366
384
367
385
def setUp (self ) -> None :
368
386
self .instance = self .tracked_class ()
@@ -411,10 +429,10 @@ def test_current(self) -> None:
411
429
self .assertCurrent (rounded = 8 )
412
430
413
431
414
- class FieldTrackedModelMultiTests (FieldTrackerTestCase ,
415
- FieldTrackerCommonTests ):
432
+ class FieldTrackedModelMultiTests (FieldTrackerCommonMixin , TestCase ):
416
433
417
- tracked_class : type [models .Model ] = TrackedMultiple
434
+ tracked_class : type [TrackedMultiple | ModelTrackedMultiple ] = TrackedMultiple
435
+ instance : TrackedMultiple | ModelTrackedMultiple
418
436
419
437
def setUp (self ) -> None :
420
438
self .instance = self .tracked_class ()
@@ -503,10 +521,11 @@ def test_current(self) -> None:
503
521
self .assertCurrent (tracker = self .trackers [1 ], number = 8 )
504
522
505
523
506
- class FieldTrackerForeignKeyTests ( FieldTrackerTestCase ):
524
+ class FieldTrackerForeignKeyMixin ( FieldTrackerMixin ):
507
525
508
- fk_class : type [models .Model ] = Tracked
509
- tracked_class : type [models .Model ] = TrackedFK
526
+ fk_class : type [Tracked | ModelTracked ]
527
+ tracked_class : type [TrackedFK | ModelTrackedFK ]
528
+ instance : TrackedFK | ModelTrackedFK
510
529
511
530
def setUp (self ) -> None :
512
531
self .old_fk = self .fk_class .objects .create (number = 8 )
@@ -545,11 +564,18 @@ def test_custom_without_id(self) -> None:
545
564
self .assertCurrent (fk = self .instance .fk_id )
546
565
547
566
548
- class FieldTrackerForeignKeyPrefetchRelatedTests (FieldTrackerTestCase ):
567
+ class FieldTrackerForeignKeyTests (FieldTrackerForeignKeyMixin , TestCase ):
568
+
569
+ fk_class = Tracked
570
+ tracked_class = TrackedFK
571
+
572
+
573
+ class FieldTrackerForeignKeyPrefetchRelatedTests (FieldTrackerMixin , TestCase ):
549
574
"""Test that using `prefetch_related` on a tracked field does not raise a ValueError."""
550
575
551
576
fk_class = Tracked
552
577
tracked_class = TrackedFK
578
+ instance : TrackedFK
553
579
554
580
def setUp (self ) -> None :
555
581
model_tracked = self .fk_class .objects .create (name = "" , number = 0 )
@@ -568,10 +594,11 @@ def test_custom_without_id(self) -> None:
568
594
self .assertIsNotNone (list (self .tracked_class .objects .prefetch_related ("fk" )))
569
595
570
596
571
- class FieldTrackerTimeStampedTests (FieldTrackerTestCase ):
597
+ class FieldTrackerTimeStampedTests (FieldTrackerMixin , TestCase ):
572
598
573
599
fk_class = Tracked
574
600
tracked_class = TrackerTimeStamped
601
+ instance : TrackerTimeStamped
575
602
576
603
def setUp (self ) -> None :
577
604
self .instance = self .tracked_class .objects .create (name = 'old' , number = 1 )
@@ -607,9 +634,10 @@ class FieldTrackerInheritedForeignKeyTests(FieldTrackerForeignKeyTests):
607
634
tracked_class = InheritedTrackedFK
608
635
609
636
610
- class FieldTrackerFileFieldTests (FieldTrackerTestCase ):
637
+ class FieldTrackerFileFieldTests (FieldTrackerMixin , TestCase ):
611
638
612
639
tracked_class = TrackedFileField
640
+ instance : TrackedFileField
613
641
614
642
def setUp (self ) -> None :
615
643
self .instance = self .tracked_class ()
@@ -631,7 +659,7 @@ def test_saved_data_without_instance(self) -> None:
631
659
self .assertEqual (self .tracker .saved_data , {})
632
660
self .update_instance (some_file = self .some_file )
633
661
field_file_copy = self .tracker .saved_data .get ('some_file' )
634
- self . assertIsNotNone ( field_file_copy )
662
+ assert field_file_copy is not None
635
663
self .assertEqual (field_file_copy .__getstate__ ().get ('instance' ), None )
636
664
self .assertEqual (self .instance .some_file .instance , self .instance )
637
665
self .assertIsInstance (self .instance .some_file , FieldFile )
@@ -732,7 +760,8 @@ def test_current(self) -> None:
732
760
733
761
class ModelTrackerTests (FieldTrackerTests ):
734
762
735
- tracked_class : type [models .Model ] = ModelTracked
763
+ tracked_class : type [ModelTracked | TrackedAbstract ] = ModelTracked
764
+ instance : ModelTracked
736
765
737
766
def test_cache_compatible (self ) -> None :
738
767
cache .set ('key' , self .instance )
@@ -848,10 +877,11 @@ def test_pre_save_changed(self) -> None:
848
877
self .assertChanged ()
849
878
850
879
851
- class ModelTrackerForeignKeyTests (FieldTrackerForeignKeyTests ):
880
+ class ModelTrackerForeignKeyTests (FieldTrackerForeignKeyMixin , TestCase ):
852
881
853
882
fk_class = ModelTracked
854
883
tracked_class = ModelTrackedFK
884
+ instance : ModelTrackedFK
855
885
856
886
def test_custom_without_id (self ) -> None :
857
887
with self .assertNumQueries (2 ):
@@ -889,11 +919,11 @@ def setUp(self) -> None:
889
919
self .instance = Tracked .objects .create (number = 1 )
890
920
self .tracker = self .instance .tracker
891
921
892
- def assertChanged (self , * fields ) :
922
+ def assertChanged (self , * fields : str ) -> None :
893
923
for f in fields :
894
924
self .assertTrue (self .tracker .has_changed (f ))
895
925
896
- def assertNotChanged (self , * fields ) :
926
+ def assertNotChanged (self , * fields : str ) -> None :
897
927
for f in fields :
898
928
self .assertFalse (self .tracker .has_changed (f ))
899
929
@@ -924,7 +954,7 @@ def test_context_manager_fields(self) -> None:
924
954
def test_tracker_decorator (self ) -> None :
925
955
926
956
@Tracked .tracker
927
- def tracked_method (obj ) :
957
+ def tracked_method (obj : Tracked ) -> None :
928
958
obj .name = 'new'
929
959
self .assertChanged ('name' )
930
960
@@ -935,7 +965,7 @@ def tracked_method(obj):
935
965
def test_tracker_decorator_fields (self ) -> None :
936
966
937
967
@Tracked .tracker (fields = ['name' ])
938
- def tracked_method (obj ) :
968
+ def tracked_method (obj : Tracked ) -> None :
939
969
obj .name = 'new'
940
970
obj .number += 1
941
971
self .assertChanged ('name' , 'number' )
0 commit comments