Skip to content

Commit ec37d06

Browse files
authored
Extract through table creation to separate method (#2229)
1 parent c693e2a commit ec37d06

File tree

3 files changed

+128
-99
lines changed

3 files changed

+128
-99
lines changed

mypy_django_plugin/transformers/fields.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
1+
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast
22

33
from django.core.exceptions import FieldDoesNotExist
44
from django.db.models.fields import AutoField, Field
@@ -114,12 +114,17 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context
114114
)
115115

116116

117+
class FieldDescriptorTypes(NamedTuple):
118+
set: MypyType
119+
get: MypyType
120+
121+
117122
def get_field_descriptor_types(
118123
field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool
119-
) -> Tuple[MypyType, MypyType]:
124+
) -> FieldDescriptorTypes:
120125
set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable)
121126
get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable)
122-
return set_type, get_type
127+
return FieldDescriptorTypes(set=set_type, get=get_type)
123128

124129

125130
def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:

mypy_django_plugin/transformers/manytomany.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import NamedTuple, Optional, Tuple, Union
22

33
from mypy.checker import TypeChecker
4-
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, RefExpr, StrExpr, TypeInfo
4+
from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, Node, RefExpr, StrExpr, TypeInfo
55
from mypy.plugin import FunctionContext, MethodContext
66
from mypy.semanal import SemanticAnalyzer
77
from mypy.types import Instance, ProperType, TypeVarType, UninhabitedType
@@ -12,12 +12,12 @@
1212

1313

1414
class M2MThrough(NamedTuple):
15-
arg: Optional[Expression]
15+
arg: Optional[Node]
1616
model: ProperType
1717

1818

1919
class M2MTo(NamedTuple):
20-
arg: Expression
20+
arg: Node
2121
model: ProperType
2222
self: bool # ManyToManyField('self', ...)
2323

mypy_django_plugin/transformers/models.py

Lines changed: 117 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from mypy_django_plugin.exceptions import UnregisteredModelError
3737
from mypy_django_plugin.lib import fullnames, helpers
3838
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
39-
from mypy_django_plugin.transformers.fields import get_field_descriptor_types
39+
from mypy_django_plugin.transformers.fields import FieldDescriptorTypes, get_field_descriptor_types
4040
from mypy_django_plugin.transformers.managers import (
4141
MANAGER_METHODS_RETURNING_QUERYSET,
4242
create_manager_info_from_from_queryset_call,
@@ -644,17 +644,6 @@ def run(self) -> None:
644644
# TODO: Create abstract through models?
645645
return
646646

647-
# Start out by prefetching a couple of dependencies needed to be able to declare any
648-
# new, implicit, through model class.
649-
model_base = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME)
650-
fk_field = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME)
651-
manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
652-
if model_base is None or fk_field is None or manager_info is None:
653-
raise helpers.IncompleteDefnException()
654-
655-
from_pk = self.get_pk_instance(self.model_classdef.info)
656-
fk_set_type, fk_get_type = get_field_descriptor_types(fk_field, is_set_nullable=False, is_get_nullable=False)
657-
658647
for statement in self.statements():
659648
# Check if this part of the class body is an assignment from a 'ManyToManyField' call
660649
# <field> = ManyToManyField(...)
@@ -675,90 +664,16 @@ def run(self) -> None:
675664
continue
676665
# Resolve argument information of the 'ManyToManyField(...)' call
677666
args = self.resolve_many_to_many_arguments(statement.rvalue, context=statement)
678-
if (
679-
# Ignore calls without required 'to' argument, mypy will complain
680-
args is None
681-
or not isinstance(args.to.model, Instance)
682-
# Call has explicit 'through=', no need to create any implicit through table
683-
or args.through is not None
684-
):
667+
# Ignore calls without required 'to' argument, mypy will complain
668+
if args is None:
685669
continue
686-
687670
# Get the names of the implicit through model that will be generated
688671
through_model_name = f"{self.model_classdef.name}_{m2m_field_name}"
689-
through_model_fullname = f"{self.model_classdef.info.module_name}.{through_model_name}"
690-
# If implicit through model is already declared there's nothing more we should do
691-
through_model = self.lookup_typeinfo(through_model_fullname)
692-
if through_model is not None:
693-
continue
694-
# Declare a new, empty, implicitly generated through model class named: '<Model>_<field_name>'
695-
through_model = self.add_new_class_for_current_module(
696-
through_model_name, bases=[Instance(model_base, [])]
697-
)
698-
# We attempt to be a bit clever here and store the generated through model's fullname in
699-
# the metadata of the class containing the 'ManyToManyField' call expression, where its
700-
# identifier is the field name of the 'ManyToManyField'. This would allow the containing
701-
# model to always find the implicit through model, so that it doesn't get lost.
702-
model_metadata = helpers.get_django_metadata(self.model_classdef.info)
703-
model_metadata.setdefault("m2m_throughs", {})
704-
model_metadata["m2m_throughs"][m2m_field_name] = through_model.fullname
705-
# Add a 'pk' symbol to the model class
706-
helpers.add_new_sym_for_info(
707-
through_model, name="pk", sym_type=self.default_pk_instance.copy_modified()
708-
)
709-
# Add an 'id' symbol to the model class
710-
helpers.add_new_sym_for_info(
711-
through_model, name="id", sym_type=self.default_pk_instance.copy_modified()
712-
)
713-
# Add the foreign key to the model containing the 'ManyToManyField' call:
714-
# <containing_model> or from_<model>
715-
from_name = (
716-
f"from_{self.model_classdef.name.lower()}" if args.to.self else self.model_classdef.name.lower()
717-
)
718-
helpers.add_new_sym_for_info(
719-
through_model,
720-
name=from_name,
721-
sym_type=Instance(
722-
fk_field,
723-
[
724-
helpers.convert_any_to_type(fk_set_type, Instance(self.model_classdef.info, [])),
725-
helpers.convert_any_to_type(fk_get_type, Instance(self.model_classdef.info, [])),
726-
],
727-
),
728-
)
729-
# Add the foreign key's '_id' field: <containing_model>_id or from_<model>_id
730-
helpers.add_new_sym_for_info(through_model, name=f"{from_name}_id", sym_type=from_pk.copy_modified())
731-
# Add the foreign key to the model on the opposite side of the relation
732-
# i.e. the model given as 'to' argument to the 'ManyToManyField' call:
733-
# <other_model> or to_<model>
734-
to_name = f"to_{args.to.model.type.name.lower()}" if args.to.self else args.to.model.type.name.lower()
735-
helpers.add_new_sym_for_info(
736-
through_model,
737-
name=to_name,
738-
sym_type=Instance(
739-
fk_field,
740-
[
741-
helpers.convert_any_to_type(fk_set_type, args.to.model),
742-
helpers.convert_any_to_type(fk_get_type, args.to.model),
743-
],
744-
),
745-
)
746-
# Add the foreign key's '_id' field: <other_model>_id or to_<model>_id
747-
other_pk = self.get_pk_instance(args.to.model.type)
748-
helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified())
749-
# Add a manager named 'objects'
750-
helpers.add_new_sym_for_info(
751-
through_model,
752-
name="objects",
753-
sym_type=Instance(manager_info, [Instance(through_model, [])]),
754-
is_classvar=True,
755-
)
756-
# Also add manager as '_default_manager' attribute
757-
helpers.add_new_sym_for_info(
758-
through_model,
759-
name="_default_manager",
760-
sym_type=Instance(manager_info, [Instance(through_model, [])]),
761-
is_classvar=True,
672+
self.create_through_table_class(
673+
field_name=m2m_field_name,
674+
model_name=through_model_name,
675+
model_fullname=f"{self.model_classdef.info.module_name}.{through_model_name}",
676+
m2m_args=args,
762677
)
763678

764679
@cached_property
@@ -771,6 +686,35 @@ def default_pk_instance(self) -> Instance:
771686
list(get_field_descriptor_types(default_pk_field, is_set_nullable=True, is_get_nullable=False)),
772687
)
773688

689+
@cached_property
690+
def model_pk_instance(self) -> Instance:
691+
return self.get_pk_instance(self.model_classdef.info)
692+
693+
@cached_property
694+
def model_base(self) -> TypeInfo:
695+
info = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME)
696+
if info is None:
697+
raise helpers.IncompleteDefnException()
698+
return info
699+
700+
@cached_property
701+
def fk_field(self) -> TypeInfo:
702+
info = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME)
703+
if info is None:
704+
raise helpers.IncompleteDefnException()
705+
return info
706+
707+
@cached_property
708+
def manager_info(self) -> TypeInfo:
709+
info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME)
710+
if info is None:
711+
raise helpers.IncompleteDefnException()
712+
return info
713+
714+
@cached_property
715+
def fk_field_types(self) -> FieldDescriptorTypes:
716+
return get_field_descriptor_types(self.fk_field, is_set_nullable=False, is_get_nullable=False)
717+
774718
def get_pk_instance(self, model: TypeInfo, /) -> Instance:
775719
"""
776720
Get a primary key instance of provided model's type info. If primary key can't be resolved,
@@ -783,6 +727,86 @@ def get_pk_instance(self, model: TypeInfo, /) -> Instance:
783727
return pk.type
784728
return self.default_pk_instance
785729

730+
def create_through_table_class(
731+
self, field_name: str, model_name: str, model_fullname: str, m2m_args: M2MArguments
732+
) -> None:
733+
if (
734+
not isinstance(m2m_args.to.model, Instance)
735+
# Call has explicit 'through=', no need to create any implicit through table
736+
or m2m_args.through is not None
737+
):
738+
return
739+
740+
# If through model is already declared there's nothing more we should do
741+
through_model = self.lookup_typeinfo(model_fullname)
742+
if through_model is not None:
743+
return
744+
# Declare a new, empty, implicitly generated through model class named: '<Model>_<field_name>'
745+
through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])])
746+
# We attempt to be a bit clever here and store the generated through model's fullname in
747+
# the metadata of the class containing the 'ManyToManyField' call expression, where its
748+
# identifier is the field name of the 'ManyToManyField'. This would allow the containing
749+
# model to always find the implicit through model, so that it doesn't get lost.
750+
model_metadata = helpers.get_django_metadata(self.model_classdef.info)
751+
model_metadata.setdefault("m2m_throughs", {})
752+
model_metadata["m2m_throughs"][field_name] = through_model.fullname
753+
# Add a 'pk' symbol to the model class
754+
helpers.add_new_sym_for_info(through_model, name="pk", sym_type=self.default_pk_instance.copy_modified())
755+
# Add an 'id' symbol to the model class
756+
helpers.add_new_sym_for_info(through_model, name="id", sym_type=self.default_pk_instance.copy_modified())
757+
# Add the foreign key to the model containing the 'ManyToManyField' call:
758+
# <containing_model> or from_<model>
759+
from_name = f"from_{self.model_classdef.name.lower()}" if m2m_args.to.self else self.model_classdef.name.lower()
760+
helpers.add_new_sym_for_info(
761+
through_model,
762+
name=from_name,
763+
sym_type=Instance(
764+
self.fk_field,
765+
[
766+
helpers.convert_any_to_type(self.fk_field_types.set, Instance(self.model_classdef.info, [])),
767+
helpers.convert_any_to_type(self.fk_field_types.get, Instance(self.model_classdef.info, [])),
768+
],
769+
),
770+
)
771+
# Add the foreign key's '_id' field: <containing_model>_id or from_<model>_id
772+
helpers.add_new_sym_for_info(
773+
through_model, name=f"{from_name}_id", sym_type=self.model_pk_instance.copy_modified()
774+
)
775+
# Add the foreign key to the model on the opposite side of the relation
776+
# i.e. the model given as 'to' argument to the 'ManyToManyField' call:
777+
# <other_model> or to_<model>
778+
to_name = (
779+
f"to_{m2m_args.to.model.type.name.lower()}" if m2m_args.to.self else m2m_args.to.model.type.name.lower()
780+
)
781+
helpers.add_new_sym_for_info(
782+
through_model,
783+
name=to_name,
784+
sym_type=Instance(
785+
self.fk_field,
786+
[
787+
helpers.convert_any_to_type(self.fk_field_types.set, m2m_args.to.model),
788+
helpers.convert_any_to_type(self.fk_field_types.get, m2m_args.to.model),
789+
],
790+
),
791+
)
792+
# Add the foreign key's '_id' field: <other_model>_id or to_<model>_id
793+
other_pk = self.get_pk_instance(m2m_args.to.model.type)
794+
helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified())
795+
# Add a manager named 'objects'
796+
helpers.add_new_sym_for_info(
797+
through_model,
798+
name="objects",
799+
sym_type=Instance(self.manager_info, [Instance(through_model, [])]),
800+
is_classvar=True,
801+
)
802+
# Also add manager as '_default_manager' attribute
803+
helpers.add_new_sym_for_info(
804+
through_model,
805+
name="_default_manager",
806+
sym_type=Instance(self.manager_info, [Instance(through_model, [])]),
807+
is_classvar=True,
808+
)
809+
786810
def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]:
787811
"""
788812
Inspect a 'ManyToManyField(...)' call to collect argument data on any 'to' and

0 commit comments

Comments
 (0)