Skip to content

Commit 56d75b8

Browse files
wittekmJukkaL
authored andcommitted
Enum: deal with subclasses of Enum or IntEnum (#1601)
* Enum: deal with subclasses of Enum or IntEnum * better fullname() * Calculate MRO per-class instead of per-instance. Amend to trigger rebuild
1 parent cf9c4db commit 56d75b8

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

mypy/nodes.py

+15
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ def get_line(self) -> int: pass
5656
LITERAL_TYPE = 1
5757
LITERAL_NO = 0
5858

59+
# Hard coded name of Enum baseclass.
60+
ENUM_BASECLASS = "enum.Enum"
61+
5962
node_kinds = {
6063
LDEF: 'Ldef',
6164
GDEF: 'Gdef',
@@ -1873,6 +1876,18 @@ def calculate_mro(self) -> None:
18731876
mro = linearize_hierarchy(self)
18741877
assert mro, "Could not produce a MRO at all for %s" % (self,)
18751878
self.mro = mro
1879+
self.is_enum = self._calculate_is_enum()
1880+
1881+
def _calculate_is_enum(self) -> bool:
1882+
"""
1883+
If this is "enum.Enum" itself, then yes, it's an enum.
1884+
If the flag .is_enum has been set on anything in the MRO, it's an enum.
1885+
"""
1886+
if self.fullname() == ENUM_BASECLASS:
1887+
return True
1888+
if self.mro:
1889+
return any(type_info.is_enum for type_info in self.mro)
1890+
return False
18761891

18771892
def has_base(self, fullname: str) -> bool:
18781893
"""Return True if type has a base type with the specified name.

mypy/semanal.py

-13
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,6 @@
127127
'builtins.bytearray': 'builtins.str',
128128
})
129129

130-
# Hard coded list of Enum baseclasses.
131-
ENUM_BASECLASSES = [
132-
'enum.Enum',
133-
'enum.IntEnum',
134-
]
135-
136130
# When analyzing a function, should we analyze the whole function in one go, or
137131
# should we only perform one phase of the analysis? The latter is used for
138132
# nested functions. In the first phase we add the function to the symbol table
@@ -765,8 +759,6 @@ def analyze_base_classes(self, defn: ClassDef) -> None:
765759
defn.info.fallback_to_any = True
766760
elif not isinstance(base, UnboundType):
767761
self.fail('Invalid base class', base_expr)
768-
if isinstance(base, Instance):
769-
defn.info.is_enum = self.decide_is_enum(base)
770762
# Add 'object' as implicit base if there is no other base class.
771763
if (not defn.base_types and defn.fullname != 'builtins.object'):
772764
obj = self.object_type()
@@ -830,11 +822,6 @@ def is_base_class(self, t: TypeInfo, s: TypeInfo) -> bool:
830822
visited.add(base.type)
831823
return False
832824

833-
def decide_is_enum(self, instance: Instance) -> bool:
834-
"""Decide if a TypeInfo should be marked as .is_enum=True"""
835-
fullname = instance.type.fullname()
836-
return fullname in ENUM_BASECLASSES
837-
838825
def analyze_metaclass(self, defn: ClassDef) -> None:
839826
if defn.metaclass:
840827
sym = self.lookup_qualified(defn.metaclass, defn)

mypy/test/data/pythoneval-enum.test

+15
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,18 @@ Color.m2('')
103103
[out]
104104
_program.py:11: error: Argument 1 to "m" of "Color" has incompatible type "str"; expected "int"
105105
_program.py:12: error: Argument 1 to "m2" of "Color" has incompatible type "str"; expected "int"
106+
107+
[case testIntEnum_ExtendedIntEnum_functionTakingExtendedIntEnum]
108+
from enum import IntEnum
109+
class ExtendedIntEnum(IntEnum):
110+
pass
111+
class SomeExtIntEnum(ExtendedIntEnum):
112+
x = 1
113+
114+
def takes_int(i: int):
115+
pass
116+
takes_int(SomeExtIntEnum.x)
117+
118+
def takes_some_ext_int_enum(s: SomeExtIntEnum):
119+
pass
120+
takes_some_ext_int_enum(SomeExtIntEnum.x)

0 commit comments

Comments
 (0)