22
33from copy import deepcopy
44from functools import wraps
5- from typing import TYPE_CHECKING , Any , Iterable , TypeVar , cast , overload
5+ from typing import (
6+ TYPE_CHECKING ,
7+ Any ,
8+ Generic ,
9+ Iterable ,
10+ Protocol ,
11+ TypeVar ,
12+ cast ,
13+ overload ,
14+ )
615
716from django .core .exceptions import FieldError
817from django .db import models
@@ -16,10 +25,22 @@ class _AugmentedModel(models.Model):
1625 _instance_initialized : bool
1726 _deferred_fields : set [str ]
1827
19-
2028T = TypeVar ("T" )
2129
2230
31+ class Descriptor (Protocol [T ]):
32+ def __get__ (self , instance : object , owner : type [object ]) -> T :
33+ ...
34+
35+ def __set__ (self , instance : object , value : T ) -> None :
36+ ...
37+
38+
39+ class FullDescriptor (Descriptor [T ]):
40+ def __delete__ (self , instance : object ) -> None :
41+ ...
42+
43+
2344class LightStateFieldFile (FieldFile ):
2445 """
2546 FieldFile subclass with the only aim to remove the instance from the state.
@@ -53,22 +74,22 @@ def lightweight_deepcopy(value: T) -> T:
5374 return deepcopy (value )
5475
5576
56- class DescriptorWrapper :
77+ class DescriptorWrapper ( Generic [ T ]) :
5778
58- def __init__ (self , field_name : str , descriptor : models . Field , tracker_attname : str ):
79+ def __init__ (self , field_name : str , descriptor : Descriptor [ T ] , tracker_attname : str ):
5980 self .field_name = field_name
6081 self .descriptor = descriptor
6182 self .tracker_attname = tracker_attname
6283
6384 @overload
64- def __get__ (self , instance : None , owner : type [models .Model ]) -> DescriptorWrapper :
85+ def __get__ (self , instance : None , owner : type [models .Model ]) -> DescriptorWrapper [ T ] :
6586 ...
6687
6788 @overload
68- def __get__ (self , instance : models .Model , owner : type [models .Model ]) -> models . Field :
89+ def __get__ (self , instance : models .Model , owner : type [models .Model ]) -> T :
6990 ...
7091
71- def __get__ (self , instance : models .Model | None , owner : type [models .Model ]) -> DescriptorWrapper | models . Field :
92+ def __get__ (self , instance : models .Model | None , owner : type [models .Model ]) -> DescriptorWrapper [ T ] | T :
7293 if instance is None :
7394 return self
7495 was_deferred = self .field_name in instance .get_deferred_fields ()
@@ -78,7 +99,7 @@ def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> D
7899 tracker_instance .saved_data [self .field_name ] = lightweight_deepcopy (value )
79100 return value
80101
81- def __set__ (self , instance : models .Model , value : models . Field ) -> None :
102+ def __set__ (self , instance : models .Model , value : T ) -> None :
82103 initialized = hasattr (instance , '_instance_initialized' )
83104 was_deferred = self .field_name in instance .get_deferred_fields ()
84105
@@ -101,23 +122,23 @@ def __set__(self, instance: models.Model, value: models.Field) -> None:
101122 else :
102123 instance .__dict__ [self .field_name ] = value
103124
104- def __getattr__ (self , attr : str ) -> models . Field :
125+ def __getattr__ (self , attr : str ) -> T :
105126 return getattr (self .descriptor , attr )
106127
107128 @staticmethod
108- def cls_for_descriptor (descriptor : models . Field ) -> type [DescriptorWrapper ]:
129+ def cls_for_descriptor (descriptor : Descriptor [ T ] ) -> type [DescriptorWrapper [ T ] ]:
109130 if hasattr (descriptor , '__delete__' ):
110131 return FullDescriptorWrapper
111132 else :
112133 return DescriptorWrapper
113134
114135
115- class FullDescriptorWrapper (DescriptorWrapper ):
136+ class FullDescriptorWrapper (DescriptorWrapper [ T ] ):
116137 """
117138 Wrapper for descriptors with all three descriptor methods.
118139 """
119- def __delete__ (self , obj : models .Field ) -> None :
120- self .descriptor .__delete__ (obj ) # type: ignore[attr-defined]
140+ def __delete__ (self , obj : models .Model ) -> None :
141+ cast ( FullDescriptor [ T ], self .descriptor ) .__delete__ (obj )
121142
122143
123144class FieldsContext :
@@ -255,7 +276,9 @@ def has_changed(self, field: str) -> bool:
255276 # deferred fields haven't changed
256277 if field in self .deferred_fields and field not in self .instance .__dict__ :
257278 return False
258- return self .previous (field ) != self .get_field_value (field )
279+ prev : object = self .previous (field )
280+ curr : object = self .get_field_value (field )
281+ return prev != curr
259282 else :
260283 raise FieldError ('field "%s" not tracked' % field )
261284
@@ -348,7 +371,7 @@ def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
348371 self .fields = (field .attname for field in sender ._meta .fields )
349372 self .fields = set (self .fields )
350373 for field_name in self .fields :
351- descriptor : models .Field = getattr (sender , field_name )
374+ descriptor : models .Field [ Any , Any ] = getattr (sender , field_name )
352375 wrapper_cls = DescriptorWrapper .cls_for_descriptor (descriptor )
353376 wrapped_descriptor = wrapper_cls (field_name , descriptor , self .attname )
354377 setattr (sender , field_name , wrapped_descriptor )
@@ -426,7 +449,9 @@ def has_changed(self, field: str) -> bool:
426449 if not self .instance .pk :
427450 return True
428451 elif field in self .saved_data :
429- return self .previous (field ) != self .get_field_value (field )
452+ prev : object = self .previous (field )
453+ curr : object = self .get_field_value (field )
454+ return prev != curr
430455 else :
431456 raise FieldError ('field "%s" not tracked' % field )
432457
0 commit comments