Skip to content

Commit 584f8f3

Browse files
authored
Improve type checking of generic function bodies (#1580)
This commit series gives type variable ids their own type and replaces the class/function type variable distinction with a plain/metavariable distinction. The main goal is to fix #603, but it's also progress towards #1261 and other bugs involving type variable inference. Metavariables (or unification variables) are variables introduced during type inference to represent the types that will be substituted for generic class or function type parameters. They only exist during type inference and should never escape into the inferred type of identifiers. Fixes #603.
1 parent a3f002f commit 584f8f3

19 files changed

+235
-223
lines changed

mypy/applytype.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import mypy.subtypes
44
from mypy.sametypes import is_same_type
55
from mypy.expandtype import expand_type
6-
from mypy.types import Type, TypeVarType, CallableType, AnyType, Void
6+
from mypy.types import Type, TypeVarId, TypeVarType, CallableType, AnyType, Void
77
from mypy.messages import MessageBuilder
88
from mypy.nodes import Context
99

@@ -48,7 +48,7 @@ def apply_generic_arguments(callable: CallableType, types: List[Type],
4848
msg.incompatible_typevar_value(callable, i + 1, type, context)
4949

5050
# Create a map from type variable id to target type.
51-
id_to_type = {} # type: Dict[int, Type]
51+
id_to_type = {} # type: Dict[TypeVarId, Type]
5252
for i, tv in enumerate(tvars):
5353
if types[i]:
5454
id_to_type[tv.id] = types[i]

mypy/checker.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from mypy.types import (
3333
Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType,
3434
Instance, NoneTyp, ErrorType, strip_type,
35-
UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType
35+
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType
3636
)
3737
from mypy.sametypes import is_same_type
3838
from mypy.messages import MessageBuilder
@@ -920,7 +920,7 @@ def check_getattr_method(self, typ: CallableType, context: Context) -> None:
920920
def expand_typevars(self, defn: FuncItem,
921921
typ: CallableType) -> List[Tuple[FuncItem, CallableType]]:
922922
# TODO use generator
923-
subst = [] # type: List[List[Tuple[int, Type]]]
923+
subst = [] # type: List[List[Tuple[TypeVarId, Type]]]
924924
tvars = typ.variables or []
925925
tvars = tvars[:]
926926
if defn.info:
@@ -2524,17 +2524,17 @@ def get_isinstance_type(node: Node, type_map: Dict[Node, Type]) -> Type:
25242524
return UnionType(types)
25252525

25262526

2527-
def expand_node(defn: Node, map: Dict[int, Type]) -> Node:
2527+
def expand_node(defn: Node, map: Dict[TypeVarId, Type]) -> Node:
25282528
visitor = TypeTransformVisitor(map)
25292529
return defn.accept(visitor)
25302530

25312531

2532-
def expand_func(defn: FuncItem, map: Dict[int, Type]) -> FuncItem:
2532+
def expand_func(defn: FuncItem, map: Dict[TypeVarId, Type]) -> FuncItem:
25332533
return cast(FuncItem, expand_node(defn, map))
25342534

25352535

25362536
class TypeTransformVisitor(TransformVisitor):
2537-
def __init__(self, map: Dict[int, Type]) -> None:
2537+
def __init__(self, map: Dict[TypeVarId, Type]) -> None:
25382538
super().__init__()
25392539
self.map = map
25402540

mypy/checkexpr.py

+38-21
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypy.types import (
66
Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef,
7-
TupleType, Instance, TypeVarType, ErasedType, UnionType,
7+
TupleType, Instance, TypeVarId, TypeVarType, ErasedType, UnionType,
88
PartialType, DeletedType, UnboundType, UninhabitedType, TypeType
99
)
1010
from mypy.nodes import (
@@ -22,7 +22,7 @@
2222
import mypy.checker
2323
from mypy import types
2424
from mypy.sametypes import is_same_type
25-
from mypy.replacetvars import replace_func_type_vars
25+
from mypy.erasetype import replace_meta_vars
2626
from mypy.messages import MessageBuilder
2727
from mypy import messages
2828
from mypy.infer import infer_type_arguments, infer_function_type_arguments
@@ -34,6 +34,7 @@
3434
from mypy.semanal import self_type
3535
from mypy.constraints import get_actual_type
3636
from mypy.checkstrformat import StringFormatterChecker
37+
from mypy.expandtype import expand_type
3738

3839
from mypy import experiments
3940

@@ -234,6 +235,7 @@ def check_call(self, callee: Type, args: List[Node],
234235
lambda i: self.accept(args[i]))
235236

236237
if callee.is_generic():
238+
callee = freshen_generic_callable(callee)
237239
callee = self.infer_function_type_arguments_using_context(
238240
callee, context)
239241
callee = self.infer_function_type_arguments(
@@ -394,12 +396,12 @@ def infer_function_type_arguments_using_context(
394396
ctx = self.chk.type_context[-1]
395397
if not ctx:
396398
return callable
397-
# The return type may have references to function type variables that
399+
# The return type may have references to type metavariables that
398400
# we are inferring right now. We must consider them as indeterminate
399401
# and they are not potential results; thus we replace them with the
400402
# special ErasedType type. On the other hand, class type variables are
401403
# valid results.
402-
erased_ctx = replace_func_type_vars(ctx, ErasedType())
404+
erased_ctx = replace_meta_vars(ctx, ErasedType())
403405
ret_type = callable.ret_type
404406
if isinstance(ret_type, TypeVarType):
405407
if ret_type.values or (not isinstance(ctx, Instance) or
@@ -1264,15 +1266,16 @@ def visit_set_expr(self, e: SetExpr) -> Type:
12641266
def check_list_or_set_expr(self, items: List[Node], fullname: str,
12651267
tag: str, context: Context) -> Type:
12661268
# Translate into type checking a generic function call.
1267-
tv = TypeVarType('T', -1, [], self.chk.object_type())
1269+
tvdef = TypeVarDef('T', -1, [], self.chk.object_type())
1270+
tv = TypeVarType(tvdef)
12681271
constructor = CallableType(
12691272
[tv],
12701273
[nodes.ARG_STAR],
12711274
[None],
12721275
self.chk.named_generic_type(fullname, [tv]),
12731276
self.named_type('builtins.function'),
12741277
name=tag,
1275-
variables=[TypeVarDef('T', -1, None, self.chk.object_type())])
1278+
variables=[tvdef])
12761279
return self.check_call(constructor,
12771280
items,
12781281
[nodes.ARG_POS] * len(items), context)[0]
@@ -1301,20 +1304,21 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type:
13011304

13021305
def visit_dict_expr(self, e: DictExpr) -> Type:
13031306
# Translate into type checking a generic function call.
1304-
tv1 = TypeVarType('KT', -1, [], self.chk.object_type())
1305-
tv2 = TypeVarType('VT', -2, [], self.chk.object_type())
1307+
ktdef = TypeVarDef('KT', -1, [], self.chk.object_type())
1308+
vtdef = TypeVarDef('VT', -2, [], self.chk.object_type())
1309+
kt = TypeVarType(ktdef)
1310+
vt = TypeVarType(vtdef)
13061311
# The callable type represents a function like this:
13071312
#
13081313
# def <unnamed>(*v: Tuple[kt, vt]) -> Dict[kt, vt]: ...
13091314
constructor = CallableType(
1310-
[TupleType([tv1, tv2], self.named_type('builtins.tuple'))],
1315+
[TupleType([kt, vt], self.named_type('builtins.tuple'))],
13111316
[nodes.ARG_STAR],
13121317
[None],
1313-
self.chk.named_generic_type('builtins.dict', [tv1, tv2]),
1318+
self.chk.named_generic_type('builtins.dict', [kt, vt]),
13141319
self.named_type('builtins.function'),
13151320
name='<list>',
1316-
variables=[TypeVarDef('KT', -1, None, self.chk.object_type()),
1317-
TypeVarDef('VT', -2, None, self.chk.object_type())])
1321+
variables=[ktdef, vtdef])
13181322
# Synthesize function arguments.
13191323
args = [] # type: List[Node]
13201324
for key, value in e.items:
@@ -1360,7 +1364,7 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> CallableType:
13601364
# they must be considered as indeterminate. We use ErasedType since it
13611365
# does not affect type inference results (it is for purposes like this
13621366
# only).
1363-
ctx = replace_func_type_vars(ctx, ErasedType())
1367+
ctx = replace_meta_vars(ctx, ErasedType())
13641368

13651369
callable_ctx = cast(CallableType, ctx)
13661370

@@ -1438,15 +1442,16 @@ def check_generator_or_comprehension(self, gen: GeneratorExpr,
14381442

14391443
# Infer the type of the list comprehension by using a synthetic generic
14401444
# callable type.
1441-
tv = TypeVarType('T', -1, [], self.chk.object_type())
1445+
tvdef = TypeVarDef('T', -1, [], self.chk.object_type())
1446+
tv = TypeVarType(tvdef)
14421447
constructor = CallableType(
14431448
[tv],
14441449
[nodes.ARG_POS],
14451450
[None],
14461451
self.chk.named_generic_type(type_name, [tv]),
14471452
self.chk.named_type('builtins.function'),
14481453
name=id_for_messages,
1449-
variables=[TypeVarDef('T', -1, None, self.chk.object_type())])
1454+
variables=[tvdef])
14501455
return self.check_call(constructor,
14511456
[gen.left_expr], [nodes.ARG_POS], gen)[0]
14521457

@@ -1456,17 +1461,18 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension):
14561461

14571462
# Infer the type of the list comprehension by using a synthetic generic
14581463
# callable type.
1459-
key_tv = TypeVarType('KT', -1, [], self.chk.object_type())
1460-
value_tv = TypeVarType('VT', -2, [], self.chk.object_type())
1464+
ktdef = TypeVarDef('KT', -1, [], self.chk.object_type())
1465+
vtdef = TypeVarDef('VT', -2, [], self.chk.object_type())
1466+
kt = TypeVarType(ktdef)
1467+
vt = TypeVarType(vtdef)
14611468
constructor = CallableType(
1462-
[key_tv, value_tv],
1469+
[kt, vt],
14631470
[nodes.ARG_POS, nodes.ARG_POS],
14641471
[None, None],
1465-
self.chk.named_generic_type('builtins.dict', [key_tv, value_tv]),
1472+
self.chk.named_generic_type('builtins.dict', [kt, vt]),
14661473
self.chk.named_type('builtins.function'),
14671474
name='<dictionary-comprehension>',
1468-
variables=[TypeVarDef('KT', -1, None, self.chk.object_type()),
1469-
TypeVarDef('VT', -2, None, self.chk.object_type())])
1475+
variables=[ktdef, vtdef])
14701476
return self.check_call(constructor,
14711477
[e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e)[0]
14721478

@@ -1775,3 +1781,14 @@ def overload_arg_similarity(actual: Type, formal: Type) -> int:
17751781
return 2
17761782
# Fall back to a conservative equality check for the remaining kinds of type.
17771783
return 2 if is_same_type(erasetype.erase_type(actual), erasetype.erase_type(formal)) else 0
1784+
1785+
1786+
def freshen_generic_callable(callee: CallableType) -> CallableType:
1787+
tvdefs = []
1788+
tvmap = {} # type: Dict[TypeVarId, Type]
1789+
for v in callee.variables:
1790+
tvdef = TypeVarDef.new_unification_variable(v)
1791+
tvdefs.append(tvdef)
1792+
tvmap[v.id] = TypeVarType(tvdef)
1793+
1794+
return cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvdefs)

mypy/checkmember.py

+6-42
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""Type checking of attribute access"""
22

3-
from typing import cast, Callable, List, Optional
3+
from typing import cast, Callable, List, Dict, Optional
44

55
from mypy.types import (
6-
Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarDef,
6+
Type, Instance, AnyType, TupleType, CallableType, FunctionLike, TypeVarId, TypeVarDef,
77
Overloaded, TypeVarType, TypeTranslator, UnionType, PartialType,
88
DeletedType, NoneTyp, TypeType
99
)
@@ -413,51 +413,15 @@ def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance,
413413
special_sig: Optional[str]) -> CallableType:
414414
"""Create a type object type based on the signature of __init__."""
415415
variables = [] # type: List[TypeVarDef]
416-
for i, tvar in enumerate(info.defn.type_vars):
417-
variables.append(TypeVarDef(tvar.name, i + 1, tvar.values, tvar.upper_bound,
418-
tvar.variance))
419-
420-
initvars = init_type.variables
421-
variables.extend(initvars)
416+
variables.extend(info.defn.type_vars)
417+
variables.extend(init_type.variables)
422418

423419
callable_type = init_type.copy_modified(
424420
ret_type=self_type(info), fallback=type_type, name=None, variables=variables,
425421
special_sig=special_sig)
426422
c = callable_type.with_name('"{}"'.format(info.name()))
427-
cc = convert_class_tvars_to_func_tvars(c, len(initvars))
428-
cc.is_classmethod_class = True
429-
return cc
430-
431-
432-
def convert_class_tvars_to_func_tvars(callable: CallableType,
433-
num_func_tvars: int) -> CallableType:
434-
return cast(CallableType, callable.accept(TvarTranslator(num_func_tvars)))
435-
436-
437-
class TvarTranslator(TypeTranslator):
438-
def __init__(self, num_func_tvars: int) -> None:
439-
super().__init__()
440-
self.num_func_tvars = num_func_tvars
441-
442-
def visit_type_var(self, t: TypeVarType) -> Type:
443-
if t.id < 0:
444-
return t
445-
else:
446-
return TypeVarType(t.name, -t.id - self.num_func_tvars, t.values, t.upper_bound,
447-
t.variance)
448-
449-
def translate_variables(self,
450-
variables: List[TypeVarDef]) -> List[TypeVarDef]:
451-
if not variables:
452-
return variables
453-
items = [] # type: List[TypeVarDef]
454-
for v in variables:
455-
if v.id > 0:
456-
items.append(TypeVarDef(v.name, -v.id - self.num_func_tvars,
457-
v.values, v.upper_bound, v.variance))
458-
else:
459-
items.append(v)
460-
return items
423+
c.is_classmethod_class = True
424+
return c
461425

462426

463427
def map_type_from_supertype(typ: Type, sub_info: TypeInfo,

mypy/constraints.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mypy.types import (
66
CallableType, Type, TypeVisitor, UnboundType, AnyType, Void, NoneTyp, TypeVarType,
77
Instance, TupleType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
8-
UninhabitedType, TypeType, is_named_instance
8+
UninhabitedType, TypeType, TypeVarId, is_named_instance
99
)
1010
from mypy.maptype import map_instance_to_supertype
1111
from mypy import nodes
@@ -23,11 +23,11 @@ class Constraint:
2323
It can be either T <: type or T :> type (T is a type variable).
2424
"""
2525

26-
type_var = 0 # Type variable id
27-
op = 0 # SUBTYPE_OF or SUPERTYPE_OF
28-
target = None # type: Type
26+
type_var = None # Type variable id
27+
op = 0 # SUBTYPE_OF or SUPERTYPE_OF
28+
target = None # type: Type
2929

30-
def __init__(self, type_var: int, op: int, target: Type) -> None:
30+
def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None:
3131
self.type_var = type_var
3232
self.op = op
3333
self.target = target

mypy/erasetype.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Optional, Container
1+
from typing import Optional, Container, Callable
22

33
from mypy.types import (
4-
Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp,
4+
Type, TypeVisitor, UnboundType, ErrorType, AnyType, Void, NoneTyp, TypeVarId,
55
Instance, TypeVarType, CallableType, TupleType, UnionType, Overloaded, ErasedType,
66
PartialType, DeletedType, TypeTranslator, TypeList, UninhabitedType, TypeType
77
)
@@ -105,20 +105,30 @@ def visit_instance(self, t: Instance) -> Type:
105105
return Instance(t.type, [], t.line)
106106

107107

108-
def erase_typevars(t: Type, ids_to_erase: Optional[Container[int]] = None) -> Type:
108+
def erase_typevars(t: Type, ids_to_erase: Optional[Container[TypeVarId]] = None) -> Type:
109109
"""Replace all type variables in a type with any,
110110
or just the ones in the provided collection.
111111
"""
112-
return t.accept(TypeVarEraser(ids_to_erase))
112+
def erase_id(id: TypeVarId) -> bool:
113+
if ids_to_erase is None:
114+
return True
115+
return id in ids_to_erase
116+
return t.accept(TypeVarEraser(erase_id, AnyType()))
117+
118+
119+
def replace_meta_vars(t: Type, target_type: Type) -> Type:
120+
"""Replace unification variables in a type with the target type."""
121+
return t.accept(TypeVarEraser(lambda id: id.is_meta_var(), target_type))
113122

114123

115124
class TypeVarEraser(TypeTranslator):
116125
"""Implementation of type erasure"""
117126

118-
def __init__(self, ids_to_erase: Optional[Container[int]]) -> None:
119-
self.ids_to_erase = ids_to_erase
127+
def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None:
128+
self.erase_id = erase_id
129+
self.replacement = replacement
120130

121131
def visit_type_var(self, t: TypeVarType) -> Type:
122-
if self.ids_to_erase is not None and t.id not in self.ids_to_erase:
123-
return t
124-
return AnyType()
132+
if self.erase_id(t.id):
133+
return self.replacement
134+
return t

0 commit comments

Comments
 (0)