Skip to content

Commit 52e243d

Browse files
committed
Add type argument to DescriptorWrapper
This preserves the type of the wrapped descriptor (usually a field). Maybe this is overkill, as `DescriptorWrapper` seems to only be used as part of the `FieldTracker` implementation and is not documented and barely tested. But technically, it is public API.
1 parent a57bca7 commit 52e243d

File tree

1 file changed

+41
-16
lines changed

1 file changed

+41
-16
lines changed

model_utils/tracker.py

+41-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22

33
from copy import deepcopy
44
from functools import wraps
5-
from typing import TYPE_CHECKING, Any, Iterable, TypeVar, cast, overload
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Generic,
9+
Iterable,
10+
Protocol,
11+
TypeVar,
12+
cast,
13+
overload,
14+
)
615

716
from django.core.exceptions import FieldError
817
from django.db import models
@@ -16,10 +25,22 @@ class _AugmentedModel(models.Model):
1625
_instance_initialized: bool
1726
_deferred_fields: set[str]
1827

19-
2028
T = TypeVar("T")
2129

2230

31+
class Descriptor(Protocol[T]):
32+
def __get__(self, instance: object, owner: type[object]) -> T:
33+
...
34+
35+
def __set__(self, instance: object, value: T) -> None:
36+
...
37+
38+
39+
class FullDescriptor(Descriptor[T]):
40+
def __delete__(self, instance: object) -> None:
41+
...
42+
43+
2344
class LightStateFieldFile(FieldFile):
2445
"""
2546
FieldFile subclass with the only aim to remove the instance from the state.
@@ -53,22 +74,22 @@ def lightweight_deepcopy(value: T) -> T:
5374
return deepcopy(value)
5475

5576

56-
class DescriptorWrapper:
77+
class DescriptorWrapper(Generic[T]):
5778

58-
def __init__(self, field_name: str, descriptor: models.Field, tracker_attname: str):
79+
def __init__(self, field_name: str, descriptor: Descriptor[T], tracker_attname: str):
5980
self.field_name = field_name
6081
self.descriptor = descriptor
6182
self.tracker_attname = tracker_attname
6283

6384
@overload
64-
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper:
85+
def __get__(self, instance: None, owner: type[models.Model]) -> DescriptorWrapper[T]:
6586
...
6687

6788
@overload
68-
def __get__(self, instance: models.Model, owner: type[models.Model]) -> models.Field:
89+
def __get__(self, instance: models.Model, owner: type[models.Model]) -> T:
6990
...
7091

71-
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper | models.Field:
92+
def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> DescriptorWrapper[T] | T:
7293
if instance is None:
7394
return self
7495
was_deferred = self.field_name in instance.get_deferred_fields()
@@ -78,7 +99,7 @@ def __get__(self, instance: models.Model | None, owner: type[models.Model]) -> D
7899
tracker_instance.saved_data[self.field_name] = lightweight_deepcopy(value)
79100
return value
80101

81-
def __set__(self, instance: models.Model, value: models.Field) -> None:
102+
def __set__(self, instance: models.Model, value: T) -> None:
82103
initialized = hasattr(instance, '_instance_initialized')
83104
was_deferred = self.field_name in instance.get_deferred_fields()
84105

@@ -101,23 +122,23 @@ def __set__(self, instance: models.Model, value: models.Field) -> None:
101122
else:
102123
instance.__dict__[self.field_name] = value
103124

104-
def __getattr__(self, attr: str) -> models.Field:
125+
def __getattr__(self, attr: str) -> T:
105126
return getattr(self.descriptor, attr)
106127

107128
@staticmethod
108-
def cls_for_descriptor(descriptor: models.Field) -> type[DescriptorWrapper]:
129+
def cls_for_descriptor(descriptor: Descriptor[T]) -> type[DescriptorWrapper[T]]:
109130
if hasattr(descriptor, '__delete__'):
110131
return FullDescriptorWrapper
111132
else:
112133
return DescriptorWrapper
113134

114135

115-
class FullDescriptorWrapper(DescriptorWrapper):
136+
class FullDescriptorWrapper(DescriptorWrapper[T]):
116137
"""
117138
Wrapper for descriptors with all three descriptor methods.
118139
"""
119-
def __delete__(self, obj: models.Field) -> None:
120-
self.descriptor.__delete__(obj) # type: ignore[attr-defined]
140+
def __delete__(self, obj: models.Model) -> None:
141+
cast(FullDescriptor[T], self.descriptor).__delete__(obj)
121142

122143

123144
class FieldsContext:
@@ -255,7 +276,9 @@ def has_changed(self, field: str) -> bool:
255276
# deferred fields haven't changed
256277
if field in self.deferred_fields and field not in self.instance.__dict__:
257278
return False
258-
return self.previous(field) != self.get_field_value(field)
279+
prev: object = self.previous(field)
280+
curr: object = self.get_field_value(field)
281+
return prev != curr
259282
else:
260283
raise FieldError('field "%s" not tracked' % field)
261284

@@ -348,7 +371,7 @@ def finalize_class(self, sender: type[models.Model], **kwargs: object) -> None:
348371
self.fields = (field.attname for field in sender._meta.fields)
349372
self.fields = set(self.fields)
350373
for field_name in self.fields:
351-
descriptor: models.Field = getattr(sender, field_name)
374+
descriptor: models.Field[Any, Any] = getattr(sender, field_name)
352375
wrapper_cls = DescriptorWrapper.cls_for_descriptor(descriptor)
353376
wrapped_descriptor = wrapper_cls(field_name, descriptor, self.attname)
354377
setattr(sender, field_name, wrapped_descriptor)
@@ -426,7 +449,9 @@ def has_changed(self, field: str) -> bool:
426449
if not self.instance.pk:
427450
return True
428451
elif field in self.saved_data:
429-
return self.previous(field) != self.get_field_value(field)
452+
prev: object = self.previous(field)
453+
curr: object = self.get_field_value(field)
454+
return prev != curr
430455
else:
431456
raise FieldError('field "%s" not tracked' % field)
432457

0 commit comments

Comments
 (0)