Skip to content

Commit f4653f0

Browse files
committed
Preserve tracked function's return type in FieldTracker
1 parent 23a756e commit f4653f0

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

model_utils/tracker.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -296,14 +296,30 @@ def __init__(self, fields: Iterable[str] | None = None):
296296
# finalize_class() will replace None; pretend it is never None.
297297
self.fields = cast(Iterable[str], fields)
298298

299+
@overload
300+
def __call__(
301+
self,
302+
func: None = None,
303+
fields: Iterable[str] | None = None
304+
) -> Callable[[Callable[..., T]], Callable[..., T]]:
305+
...
306+
307+
@overload
308+
def __call__(
309+
self,
310+
func: Callable[..., T],
311+
fields: Iterable[str] | None = None
312+
) -> Callable[..., T]:
313+
...
314+
299315
def __call__(
300316
self,
301-
func: Callable | None = None,
317+
func: Callable[..., T] | None = None,
302318
fields: Iterable[str] | None = None
303-
) -> Any:
304-
def decorator(f: Callable) -> Callable:
319+
) -> Callable[[Callable[..., T]], Callable[..., T]] | Callable[..., T]:
320+
def decorator(f: Callable[..., T]) -> Callable[..., T]:
305321
@wraps(f)
306-
def inner(obj: models.Model, *args: object, **kwargs: object) -> object:
322+
def inner(obj: models.Model, *args: object, **kwargs: object) -> T:
307323
tracker = getattr(obj, self.attname)
308324
field_list = tracker.fields if fields is None else fields
309325
with tracker(*field_list):

0 commit comments

Comments
 (0)