2
2
3
3
from copy import deepcopy
4
4
from functools import wraps
5
- from typing import TYPE_CHECKING , Any , Iterable , TypeVar , cast , overload
5
+ from typing import (
6
+ TYPE_CHECKING ,
7
+ Any ,
8
+ Generic ,
9
+ Iterable ,
10
+ Protocol ,
11
+ TypeVar ,
12
+ cast ,
13
+ overload ,
14
+ )
6
15
7
16
from django .core .exceptions import FieldError
8
17
from django .db import models
@@ -16,10 +25,22 @@ class _AugmentedModel(models.Model):
16
25
_instance_initialized : bool
17
26
_deferred_fields : set [str ]
18
27
19
-
20
28
T = TypeVar ("T" )
21
29
22
30
31
+ class Descriptor (Protocol [T ]):
32
+ def __get__ (self , instance : object , owner : type [object ]) -> T :
33
+ ...
34
+
35
+ def __set__ (self , instance : object , value : T ) -> None :
36
+ ...
37
+
38
+
39
+ class FullDescriptor (Descriptor [T ]):
40
+ def __delete__ (self , instance : object ) -> None :
41
+ ...
42
+
43
+
23
44
class LightStateFieldFile (FieldFile ):
24
45
"""
25
46
FieldFile subclass with the only aim to remove the instance from the state.
@@ -53,22 +74,22 @@ def lightweight_deepcopy(value: T) -> T:
53
74
return deepcopy (value )
54
75
55
76
56
- class DescriptorWrapper :
77
+ class DescriptorWrapper ( Generic [ T ]) :
57
78
58
- def __init__ (self , field_name : str , descriptor : models . Field , tracker_attname : str ):
79
+ def __init__ (self , field_name : str , descriptor : Descriptor [ T ] , tracker_attname : str ):
59
80
self .field_name = field_name
60
81
self .descriptor = descriptor
61
82
self .tracker_attname = tracker_attname
62
83
63
84
@overload
64
- def __get__ (self , instance : None , owner : type [models .Model ]) -> DescriptorWrapper :
85
+ def __get__ (self , instance : None , owner : type [models .Model ]) -> DescriptorWrapper [ T ] :
65
86
...
66
87
67
88
@overload
68
- def __get__ (self , instance : models .Model , owner : type [models .Model ]) -> models . Field :
89
+ def __get__ (self , instance : models .Model , owner : type [models .Model ]) -> T :
69
90
...
70
91
71
- def __get__ (self , instance : models .Model | None , owner : type [models .Model ]) -> DescriptorWrapper | models . Field :
92
+ def __get__ (self , instance : models .Model | None , owner : type [models .Model ]) -> DescriptorWrapper [ T ] | T :
72
93
if instance is None :
73
94
return self
74
95
was_deferred = self .field_name in instance .get_deferred_fields ()
@@ -78,7 +99,7 @@ def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> D
78
99
tracker_instance .saved_data [self .field_name ] = lightweight_deepcopy (value )
79
100
return value
80
101
81
- def __set__ (self , instance : models .Model , value : models . Field ) -> None :
102
+ def __set__ (self , instance : models .Model , value : T ) -> None :
82
103
initialized = hasattr (instance , '_instance_initialized' )
83
104
was_deferred = self .field_name in instance .get_deferred_fields ()
84
105
@@ -101,23 +122,23 @@ def __set__(self, instance: models.Model, value: models.Field) -> None:
101
122
else :
102
123
instance .__dict__ [self .field_name ] = value
103
124
104
- def __getattr__ (self , attr : str ) -> models . Field :
125
+ def __getattr__ (self , attr : str ) -> T :
105
126
return getattr (self .descriptor , attr )
106
127
107
128
@staticmethod
108
- def cls_for_descriptor (descriptor : models . Field ) -> type [DescriptorWrapper ]:
129
+ def cls_for_descriptor (descriptor : Descriptor [ T ] ) -> type [DescriptorWrapper [ T ] ]:
109
130
if hasattr (descriptor , '__delete__' ):
110
131
return FullDescriptorWrapper
111
132
else :
112
133
return DescriptorWrapper
113
134
114
135
115
- class FullDescriptorWrapper (DescriptorWrapper ):
136
+ class FullDescriptorWrapper (DescriptorWrapper [ T ] ):
116
137
"""
117
138
Wrapper for descriptors with all three descriptor methods.
118
139
"""
119
- def __delete__ (self , obj : models .Field ) -> None :
120
- self .descriptor .__delete__ (obj ) # type: ignore[attr-defined]
140
+ def __delete__ (self , obj : models .Model ) -> None :
141
+ cast ( FullDescriptor [ T ], self .descriptor ) .__delete__ (obj )
121
142
122
143
123
144
class FieldsContext :
@@ -255,7 +276,9 @@ def has_changed(self, field: str) -> bool:
255
276
# deferred fields haven't changed
256
277
if field in self .deferred_fields and field not in self .instance .__dict__ :
257
278
return False
258
- return self .previous (field ) != self .get_field_value (field )
279
+ prev : object = self .previous (field )
280
+ curr : object = self .get_field_value (field )
281
+ return prev != curr
259
282
else :
260
283
raise FieldError ('field "%s" not tracked' % field )
261
284
@@ -348,7 +371,7 @@ def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
348
371
self .fields = (field .attname for field in sender ._meta .fields )
349
372
self .fields = set (self .fields )
350
373
for field_name in self .fields :
351
- descriptor : models .Field = getattr (sender , field_name )
374
+ descriptor : models .Field [ Any , Any ] = getattr (sender , field_name )
352
375
wrapper_cls = DescriptorWrapper .cls_for_descriptor (descriptor )
353
376
wrapped_descriptor = wrapper_cls (field_name , descriptor , self .attname )
354
377
setattr (sender , field_name , wrapped_descriptor )
@@ -426,7 +449,9 @@ def has_changed(self, field: str) -> bool:
426
449
if not self .instance .pk :
427
450
return True
428
451
elif field in self .saved_data :
429
- return self .previous (field ) != self .get_field_value (field )
452
+ prev : object = self .previous (field )
453
+ curr : object = self .get_field_value (field )
454
+ return prev != curr
430
455
else :
431
456
raise FieldError ('field "%s" not tracked' % field )
432
457
0 commit comments