Skip to content

Commit 7914b2d

Browse files
authored
Fix mypyc crash with enum type aliases (#18725)
mypyc was crashing because it couldn't find the type in the type map. This PR adds a generic AnyType to the type map if an expression isn't in the map already. Tried actually changing mypy to accept these type alias expressions, but ran into problems with nested type aliases where the inner one doesn't have the "analyzed" value and ending up with wrong results. fixes mypyc/mypyc#1064
1 parent 256cf68 commit 7914b2d

File tree

5 files changed

+53
-2
lines changed

5 files changed

+53
-2
lines changed

mypyc/irbuild/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def build_ir(
7373

7474
for module in modules:
7575
# First pass to determine free symbols.
76-
pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove)
76+
pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove, types)
7777
module.accept(pbv)
7878

7979
# Construct and configure builder objects (cyclic runtime dependency).

mypyc/irbuild/missingtypevisitor.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
from mypy.nodes import Expression, Node
4+
from mypy.traverser import ExtendedTraverserVisitor
5+
from mypy.types import AnyType, Type, TypeOfAny
6+
7+
8+
class MissingTypesVisitor(ExtendedTraverserVisitor):
9+
"""AST visitor that can be used to add any missing types as a generic AnyType."""
10+
11+
def __init__(self, types: dict[Expression, Type]) -> None:
12+
super().__init__()
13+
self.types: dict[Expression, Type] = types
14+
15+
def visit(self, o: Node) -> bool:
16+
if isinstance(o, Expression) and o not in self.types:
17+
self.types[o] = AnyType(TypeOfAny.special_form)
18+
19+
# If returns True, will continue to nested nodes.
20+
return True

mypyc/irbuild/prebuildvisitor.py

+13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from mypy.nodes import (
4+
AssignmentStmt,
45
Block,
56
Decorator,
67
Expression,
@@ -16,7 +17,9 @@
1617
Var,
1718
)
1819
from mypy.traverser import ExtendedTraverserVisitor
20+
from mypy.types import Type
1921
from mypyc.errors import Errors
22+
from mypyc.irbuild.missingtypevisitor import MissingTypesVisitor
2023

2124

2225
class PreBuildVisitor(ExtendedTraverserVisitor):
@@ -39,6 +42,7 @@ def __init__(
3942
errors: Errors,
4043
current_file: MypyFile,
4144
decorators_to_remove: dict[FuncDef, list[int]],
45+
types: dict[Expression, Type],
4246
) -> None:
4347
super().__init__()
4448
# Dict from a function to symbols defined directly in the
@@ -82,11 +86,20 @@ def __init__(
8286

8387
self.current_file: MypyFile = current_file
8488

89+
self.missing_types_visitor = MissingTypesVisitor(types)
90+
8591
def visit(self, o: Node) -> bool:
8692
if not isinstance(o, Import):
8793
self._current_import_group = None
8894
return True
8995

96+
def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
97+
# These are cases where mypy may not have types for certain expressions,
98+
# but mypyc needs some form type to exist.
99+
if stmt.is_alias_def:
100+
stmt.rvalue.accept(self.missing_types_visitor)
101+
return super().visit_assignment_stmt(stmt)
102+
90103
def visit_block(self, block: Block) -> None:
91104
self._current_import_group = None
92105
super().visit_block(block)

mypyc/test-data/irbuild-classes.test

+10
Original file line numberDiff line numberDiff line change
@@ -1335,3 +1335,13 @@ def outer():
13351335
if True:
13361336
class OtherInner: # E: Nested class definitions not supported
13371337
pass
1338+
1339+
[case testEnumClassAlias]
1340+
from enum import Enum
1341+
from typing import Literal, Union
1342+
1343+
class SomeEnum(Enum):
1344+
AVALUE = "a"
1345+
1346+
ALIAS = Literal[SomeEnum.AVALUE]
1347+
ALIAS2 = Union[Literal[SomeEnum.AVALUE], None]

mypyc/test-data/run-python312.test

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[case testPEP695Basics]
2-
from typing import Any, TypeAliasType, cast
2+
from enum import Enum
3+
from typing import Any, Literal, TypeAliasType, cast
34

45
from testutil import assertRaises
56

@@ -188,6 +189,13 @@ type R = int | list[R]
188189
def test_recursive_type_alias() -> None:
189190
assert isinstance(R, TypeAliasType)
190191
assert getattr(R, "__value__") == (int | list[R])
192+
193+
class SomeEnum(Enum):
194+
AVALUE = "a"
195+
196+
type EnumLiteralAlias1 = Literal[SomeEnum.AVALUE]
197+
type EnumLiteralAlias2 = Literal[SomeEnum.AVALUE] | None
198+
EnumLiteralAlias3 = Literal[SomeEnum.AVALUE] | None
191199
[typing fixtures/typing-full.pyi]
192200

193201
[case testPEP695GenericTypeAlias]

0 commit comments

Comments
 (0)