Skip to content

Commit 1db7d6b

Browse files
committed
Fix type generics in InheritanceIterable
1 parent f4653f0 commit 1db7d6b

File tree

1 file changed

+43
-35
lines changed

1 file changed

+43
-35
lines changed

model_utils/managers.py

+43-35
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from django.db import connection, models
88
from django.db.models.constants import LOOKUP_SEP
99
from django.db.models.fields.related import OneToOneField, OneToOneRel
10+
from django.db.models.query import ModelIterable, QuerySet
1011
from django.db.models.sql.datastructures import Join
1112

1213
ModelT = TypeVar('ModelT', bound=models.Model, covariant=True)
@@ -15,44 +16,51 @@
1516
from collections.abc import Iterator
1617

1718
from django.db.models.query import BaseIterable
18-
from django.db.models.query import ModelIterable as ModelIterableGeneric
19-
from django.db.models.query import QuerySet as QuerySetGeneric
2019

21-
ModelIterable = ModelIterableGeneric[ModelT]
22-
QuerySet = QuerySetGeneric[ModelT]
23-
else:
24-
from django.db.models.query import ModelIterable, QuerySet
25-
26-
27-
class InheritanceIterable(ModelIterable):
28-
def __iter__(self) -> Iterator[ModelT]:
29-
queryset = self.queryset
30-
iter: ModelIterableGeneric[ModelT] = ModelIterable(queryset)
31-
if hasattr(queryset, 'subclasses'):
32-
assert hasattr(queryset, '_get_sub_obj_recurse')
33-
extras = tuple(queryset.query.extra.keys())
34-
# sort the subclass names longest first,
35-
# so with 'a' and 'a__b' it goes as deep as possible
36-
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
37-
for obj in iter:
38-
sub_obj = None
39-
for s in subclasses:
40-
sub_obj = queryset._get_sub_obj_recurse(obj, s)
41-
if sub_obj:
42-
break
43-
if not sub_obj:
44-
sub_obj = obj
45-
46-
if hasattr(queryset, '_annotated'):
47-
for k in queryset._annotated:
48-
setattr(sub_obj, k, getattr(obj, k))
49-
50-
for k in extras:
20+
21+
def _iter_inheritance_queryset(queryset: QuerySet[ModelT]) -> Iterator[ModelT]:
22+
iter: ModelIterable[ModelT] = ModelIterable(queryset)
23+
if hasattr(queryset, 'subclasses'):
24+
assert hasattr(queryset, '_get_sub_obj_recurse')
25+
extras = tuple(queryset.query.extra.keys())
26+
# sort the subclass names longest first,
27+
# so with 'a' and 'a__b' it goes as deep as possible
28+
subclasses = sorted(queryset.subclasses, key=len, reverse=True)
29+
for obj in iter:
30+
sub_obj = None
31+
for s in subclasses:
32+
sub_obj = queryset._get_sub_obj_recurse(obj, s)
33+
if sub_obj:
34+
break
35+
if not sub_obj:
36+
sub_obj = obj
37+
38+
if hasattr(queryset, '_annotated'):
39+
for k in queryset._annotated:
5140
setattr(sub_obj, k, getattr(obj, k))
5241

53-
yield sub_obj
54-
else:
55-
yield from iter
42+
for k in extras:
43+
setattr(sub_obj, k, getattr(obj, k))
44+
45+
yield sub_obj
46+
else:
47+
yield from iter
48+
49+
50+
if TYPE_CHECKING:
51+
class InheritanceIterable(ModelIterable[ModelT]):
52+
queryset: QuerySet[ModelT]
53+
54+
def __init__(self, queryset: QuerySet[ModelT], *args: Any, **kwargs: Any):
55+
...
56+
57+
def __iter__(self) -> Iterator[ModelT]:
58+
...
59+
60+
else:
61+
class InheritanceIterable(ModelIterable):
62+
def __iter__(self):
63+
return _iter_inheritance_queryset(self.queryset)
5664

5765

5866
class InheritanceQuerySetMixin(Generic[ModelT]):

0 commit comments

Comments
 (0)