Skip to content

Commit a9ce2cb

Browse files
ilevkivskyiericmarkmartin
authored andcommitted
Better handling of Any/object in variadic generics (python#18643)
Fixes python#18407 Fixes python#17184 Fixes python#16567 There are three things here: * Allow erased variadic callables with non-empty prefix to be supertypes of the non-erased ones. This relaxes a bit callable subtyping in general, but IMO this makes sense, people who want to be strict should simply use `*args: object` instead. An alternative would be to track erased variadic callables explicitly, which is ugly and fragile. * Add important missing case in `subtypes.py` for `*Ts` w.r.t. `Any`/`object` that handles similar situations for variadic instances and tuples (here however there is nothing special about `Any` vs `object`). * I also fix inconsistency in join uncovered by the above two. The changes in `expandtype.py` are no-op, I just noticed potential danger while playing with this, so wanted to highlight it with comments for the future.
1 parent 191ec80 commit a9ce2cb

File tree

6 files changed

+164
-28
lines changed

6 files changed

+164
-28
lines changed

mypy/erasetype.py

+8
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,14 @@ def visit_tuple_type(self, t: TupleType) -> Type:
203203
return unpacked
204204
return result
205205

206+
def visit_callable_type(self, t: CallableType) -> Type:
207+
result = super().visit_callable_type(t)
208+
assert isinstance(result, ProperType) and isinstance(result, CallableType)
209+
# Usually this is done in semanal_typeargs.py, but erasure can create
210+
# a non-normal callable from normal one.
211+
result.normalize_trivial_unpack()
212+
return result
213+
206214
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
207215
if self.erase_id(t.id):
208216
return t.tuple_fallback.copy_modified(args=[self.replacement])

mypy/expandtype.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def visit_instance(self, t: Instance) -> Type:
226226
if isinstance(arg, UnpackType):
227227
unpacked = get_proper_type(arg.type)
228228
if isinstance(unpacked, Instance):
229+
# TODO: this and similar asserts below may be unsafe because get_proper_type()
230+
# may be called during semantic analysis before all invalid types are removed.
229231
assert unpacked.type.fullname == "builtins.tuple"
230232
args = list(unpacked.args)
231233
return t.copy_modified(args=args)
@@ -333,10 +335,7 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l
333335

334336
var_arg_type = get_proper_type(var_arg.type)
335337
new_unpack: Type
336-
if isinstance(var_arg_type, Instance):
337-
# we have something like Unpack[Tuple[Any, ...]]
338-
new_unpack = UnpackType(var_arg.type.accept(self))
339-
elif isinstance(var_arg_type, TupleType):
338+
if isinstance(var_arg_type, TupleType):
340339
# We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
341340
expanded_tuple = var_arg_type.accept(self)
342341
assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
@@ -348,6 +347,11 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l
348347
fallback = var_arg_type.tuple_fallback
349348
expanded_items = self.expand_unpack(var_arg)
350349
new_unpack = UnpackType(TupleType(expanded_items, fallback))
350+
# Since get_proper_type() may be called in semanal.py before callable
351+
# normalization happens, we need to also handle non-normal cases here.
352+
elif isinstance(var_arg_type, Instance):
353+
# we have something like Unpack[Tuple[Any, ...]]
354+
new_unpack = UnpackType(var_arg.type.accept(self))
351355
else:
352356
# We have invalid type in Unpack. This can happen when expanding aliases
353357
# to Callable[[*Invalid], Ret]

mypy/join.py

+9
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,9 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType:
299299
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
300300
if self.s == t:
301301
return t
302+
if isinstance(self.s, Instance) and is_subtype(t.upper_bound, self.s):
303+
# TODO: should we do this more generally and for all TypeVarLikeTypes?
304+
return self.s
302305
return self.default(self.s)
303306

304307
def visit_unpack_type(self, t: UnpackType) -> UnpackType:
@@ -350,6 +353,8 @@ def visit_instance(self, t: Instance) -> ProperType:
350353
return join_types(t, self.s)
351354
elif isinstance(self.s, LiteralType):
352355
return join_types(t, self.s)
356+
elif isinstance(self.s, TypeVarTupleType) and is_subtype(self.s.upper_bound, t):
357+
return t
353358
else:
354359
return self.default(self.s)
355360

@@ -562,6 +567,10 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
562567
assert isinstance(fallback, Instance)
563568
items = self.join_tuples(self.s, t)
564569
if items is not None:
570+
if len(items) == 1 and isinstance(item := items[0], UnpackType):
571+
if isinstance(unpacked := get_proper_type(item.type), Instance):
572+
# Avoid double-wrapping tuple[*tuple[X, ...]]
573+
return unpacked
565574
return TupleType(items, fallback)
566575
else:
567576
# TODO: should this be a default fallback behaviour like for meet?

mypy/subtypes.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Iterator
3+
from collections.abc import Iterable, Iterator
44
from contextlib import contextmanager
55
from typing import Any, Callable, Final, TypeVar, cast
66
from typing_extensions import TypeAlias as _TypeAlias
@@ -414,6 +414,9 @@ def _is_subtype(self, left: Type, right: Type) -> bool:
414414
return is_proper_subtype(left, right, subtype_context=self.subtype_context)
415415
return is_subtype(left, right, subtype_context=self.subtype_context)
416416

417+
def _all_subtypes(self, lefts: Iterable[Type], rights: Iterable[Type]) -> bool:
418+
return all(self._is_subtype(li, ri) for (li, ri) in zip(lefts, rights))
419+
417420
# visit_x(left) means: is left (which is an instance of X) a subtype of right?
418421

419422
def visit_unbound_type(self, left: UnboundType) -> bool:
@@ -856,11 +859,25 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
856859
# There are some items on the left that will never have a matching length
857860
# on the right.
858861
return False
862+
left_prefix = left_unpack_index
863+
left_suffix = len(left.items) - left_prefix - 1
859864
left_unpack = left.items[left_unpack_index]
860865
assert isinstance(left_unpack, UnpackType)
861866
left_unpacked = get_proper_type(left_unpack.type)
862867
if not isinstance(left_unpacked, Instance):
863-
# *Ts unpacks can't be split.
868+
# *Ts unpack can't be split, except if it is all mapped to Anys or objects.
869+
if self.is_top_type(right_item):
870+
right_prefix_types, middle, right_suffix_types = split_with_prefix_and_suffix(
871+
tuple(right.items), left_prefix, left_suffix
872+
)
873+
if not all(
874+
self.is_top_type(ri) or isinstance(ri, UnpackType) for ri in middle
875+
):
876+
return False
877+
# Also check the tails match as well.
878+
return self._all_subtypes(
879+
left.items[:left_prefix], right_prefix_types
880+
) and self._all_subtypes(left.items[-left_suffix:], right_suffix_types)
864881
return False
865882
assert left_unpacked.type.fullname == "builtins.tuple"
866883
left_item = left_unpacked.args[0]
@@ -871,8 +888,6 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
871888
# and then check subtyping for all finite overlaps.
872889
if not self._is_subtype(left_item, right_item):
873890
return False
874-
left_prefix = left_unpack_index
875-
left_suffix = len(left.items) - left_prefix - 1
876891
max_overlap = max(0, right_prefix - left_prefix, right_suffix - left_suffix)
877892
for overlap in range(max_overlap + 1):
878893
repr_items = left.items[:left_prefix] + [left_item] * overlap
@@ -883,6 +898,11 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
883898
return False
884899
return True
885900

901+
def is_top_type(self, typ: Type) -> bool:
902+
if not self.proper_subtype and isinstance(get_proper_type(typ), AnyType):
903+
return True
904+
return is_named_instance(typ, "builtins.object")
905+
886906
def visit_typeddict_type(self, left: TypedDictType) -> bool:
887907
right = self.right
888908
if isinstance(right, Instance):
@@ -1653,17 +1673,18 @@ def are_parameters_compatible(
16531673
return True
16541674
trivial_suffix = is_trivial_suffix(right) and not is_proper_subtype
16551675

1676+
trivial_vararg_suffix = False
16561677
if (
1657-
right.arg_kinds == [ARG_STAR]
1658-
and isinstance(get_proper_type(right.arg_types[0]), AnyType)
1678+
right.arg_kinds[-1:] == [ARG_STAR]
1679+
and isinstance(get_proper_type(right.arg_types[-1]), AnyType)
16591680
and not is_proper_subtype
1681+
and all(k.is_positional(star=True) for k in left.arg_kinds)
16601682
):
16611683
# Similar to how (*Any, **Any) is considered a supertype of all callables, we consider
16621684
# (*Any) a supertype of all callables with positional arguments. This is needed in
16631685
# particular because we often refuse to try type inference if actual type is not
16641686
# a subtype of erased template type.
1665-
if all(k.is_positional() for k in left.arg_kinds) and ignore_pos_arg_names:
1666-
return True
1687+
trivial_vararg_suffix = True
16671688

16681689
# Match up corresponding arguments and check them for compatibility. In
16691690
# every pair (argL, argR) of corresponding arguments from L and R, argL must
@@ -1697,7 +1718,11 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
16971718
return not allow_partial_overlap and not trivial_suffix
16981719
return not is_compat(right_arg.typ, left_arg.typ)
16991720

1700-
if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
1721+
if (
1722+
_incompatible(left_star, right_star)
1723+
and not trivial_vararg_suffix
1724+
or _incompatible(left_star2, right_star2)
1725+
):
17011726
return False
17021727

17031728
# Phase 1b: Check non-star args: for every arg right can accept, left must
@@ -1727,8 +1752,8 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
17271752
# Phase 1c: Check var args. Right has an infinite series of optional positional
17281753
# arguments. Get all further positional args of left, and make sure
17291754
# they're more general than the corresponding member in right.
1730-
# TODO: are we handling UnpackType correctly here?
1731-
if right_star is not None:
1755+
# TODO: handle suffix in UnpackType (i.e. *args: *Tuple[Ts, X, Y]).
1756+
if right_star is not None and not trivial_vararg_suffix:
17321757
# Synthesize an anonymous formal argument for the right
17331758
right_by_position = right.try_synthesizing_arg_from_vararg(None)
17341759
assert right_by_position is not None

mypy/test/testtypes.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1021,7 +1021,7 @@ def test_variadic_tuple_joins(self) -> None:
10211021
self.assert_join(
10221022
self.tuple(self.fx.a, self.fx.a),
10231023
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
1024-
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
1024+
Instance(self.fx.std_tuplei, [self.fx.a]),
10251025
)
10261026
self.assert_join(
10271027
self.tuple(self.fx.a, self.fx.a),
@@ -1049,12 +1049,12 @@ def test_variadic_tuple_joins(self) -> None:
10491049
self.tuple(
10501050
self.fx.a, UnpackType(Instance(self.fx.std_tuplei, [self.fx.a])), self.fx.a
10511051
),
1052-
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
1052+
Instance(self.fx.std_tuplei, [self.fx.a]),
10531053
)
10541054
self.assert_join(
10551055
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
10561056
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
1057-
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
1057+
Instance(self.fx.std_tuplei, [self.fx.a]),
10581058
)
10591059
self.assert_join(
10601060
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a])), self.fx.a),
@@ -1584,11 +1584,12 @@ def make_call(*items: tuple[str, str | None]) -> CallExpr:
15841584
class TestExpandTypeLimitGetProperType(TestCase):
15851585
# WARNING: do not increase this number unless absolutely necessary,
15861586
# and you understand what you are doing.
1587-
ALLOWED_GET_PROPER_TYPES = 9
1587+
ALLOWED_GET_PROPER_TYPES = 7
15881588

15891589
@skipUnless(mypy.expandtype.__file__.endswith(".py"), "Skip for compiled mypy")
15901590
def test_count_get_proper_type(self) -> None:
15911591
with open(mypy.expandtype.__file__) as f:
15921592
code = f.read()
1593-
get_proper_type_count = len(re.findall("get_proper_type", code))
1593+
get_proper_type_count = len(re.findall(r"get_proper_type\(", code))
1594+
get_proper_type_count -= len(re.findall(r"get_proper_type\(\)", code))
15941595
assert get_proper_type_count == self.ALLOWED_GET_PROPER_TYPES

test-data/unit/check-typevar-tuple.test

+97-8
Original file line numberDiff line numberDiff line change
@@ -2309,18 +2309,21 @@ def higher_order(f: _CallableValue) -> None: ...
23092309
def good1(*args: int) -> None: ...
23102310
def good2(*args: str) -> int: ...
23112311

2312-
def bad1(a: str, b: int, /) -> None: ...
2313-
def bad2(c: bytes, *args: int) -> str: ...
2314-
def bad3(*, d: str) -> int: ...
2315-
def bad4(**kwargs: None) -> None: ...
2312+
# These are special-cased for *args: Any (as opposite to *args: object)
2313+
def ok1(a: str, b: int, /) -> None: ...
2314+
def ok2(c: bytes, *args: int) -> str: ...
2315+
2316+
def bad1(*, d: str) -> int: ...
2317+
def bad2(**kwargs: None) -> None: ...
23162318

23172319
higher_order(good1)
23182320
higher_order(good2)
23192321

2320-
higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[str, int], None]"; expected "Callable[[VarArg(Any)], Any]"
2321-
higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[bytes, VarArg(int)], str]"; expected "Callable[[VarArg(Any)], Any]"
2322-
higher_order(bad3) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]"
2323-
higher_order(bad4) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]"
2322+
higher_order(ok1)
2323+
higher_order(ok2)
2324+
2325+
higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]"
2326+
higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]"
23242327
[builtins fixtures/tuple.pyi]
23252328

23262329
[case testAliasToCallableWithUnpack2]
@@ -2517,3 +2520,89 @@ x4: Foo[Unpack[tuple[str, ...]]]
25172520
y4: Foo[Unpack[tuple[int, int]]]
25182521
x4 is y4 # E: Non-overlapping identity check (left operand type: "Foo[Unpack[Tuple[str, ...]]]", right operand type: "Foo[int, int]")
25192522
[builtins fixtures/tuple.pyi]
2523+
2524+
[case testTypeVarTupleErasureNormalized]
2525+
from typing import TypeVarTuple, Unpack, Generic, Union
2526+
from collections.abc import Callable
2527+
2528+
Args = TypeVarTuple("Args")
2529+
2530+
class Built(Generic[Unpack[Args]]):
2531+
pass
2532+
2533+
def example(
2534+
fn: Union[Built[Unpack[Args]], Callable[[Unpack[Args]], None]]
2535+
) -> Built[Unpack[Args]]: ...
2536+
2537+
@example
2538+
def command() -> None:
2539+
return
2540+
reveal_type(command) # N: Revealed type is "__main__.Built[()]"
2541+
[builtins fixtures/tuple.pyi]
2542+
2543+
[case testTypeVarTupleSelfMappedPrefix]
2544+
from typing import TypeVarTuple, Generic, Unpack
2545+
2546+
Ts = TypeVarTuple("Ts")
2547+
class Base(Generic[Unpack[Ts]]):
2548+
attr: tuple[Unpack[Ts]]
2549+
2550+
@property
2551+
def prop(self) -> tuple[Unpack[Ts]]:
2552+
return self.attr
2553+
2554+
def meth(self) -> tuple[Unpack[Ts]]:
2555+
return self.attr
2556+
2557+
Ss = TypeVarTuple("Ss")
2558+
class Derived(Base[str, Unpack[Ss]]):
2559+
def test(self) -> None:
2560+
reveal_type(self.attr) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
2561+
reveal_type(self.prop) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
2562+
reveal_type(self.meth()) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
2563+
[builtins fixtures/property.pyi]
2564+
2565+
[case testTypeVarTupleProtocolPrefix]
2566+
from typing import Protocol, Unpack, TypeVarTuple
2567+
2568+
Ts = TypeVarTuple("Ts")
2569+
class A(Protocol[Unpack[Ts]]):
2570+
def f(self, z: str, *args: Unpack[Ts]) -> None: ...
2571+
2572+
class C:
2573+
def f(self, z: str, x: int) -> None: ...
2574+
2575+
def f(x: A[Unpack[Ts]]) -> tuple[Unpack[Ts]]: ...
2576+
2577+
reveal_type(f(C())) # N: Revealed type is "Tuple[builtins.int]"
2578+
[builtins fixtures/tuple.pyi]
2579+
2580+
[case testTypeVarTupleHomogeneousCallableNormalized]
2581+
from typing import Generic, Unpack, TypeVarTuple
2582+
2583+
Ts = TypeVarTuple("Ts")
2584+
class C(Generic[Unpack[Ts]]):
2585+
def foo(self, *args: Unpack[Ts]) -> None: ...
2586+
2587+
c: C[Unpack[tuple[int, ...]]]
2588+
reveal_type(c.foo) # N: Revealed type is "def (*args: builtins.int)"
2589+
[builtins fixtures/tuple.pyi]
2590+
2591+
[case testTypeVarTupleJoinInstanceTypeVar]
2592+
from typing import Any, Unpack, TypeVarTuple, TypeVar
2593+
2594+
T = TypeVar("T")
2595+
Ts = TypeVarTuple("Ts")
2596+
2597+
def join(x: T, y: T) -> T: ...
2598+
def test(xs: tuple[Unpack[Ts]], xsi: tuple[int, Unpack[Ts]]) -> None:
2599+
a: tuple[Any, ...]
2600+
reveal_type(join(xs, a)) # N: Revealed type is "builtins.tuple[Any, ...]"
2601+
reveal_type(join(a, xs)) # N: Revealed type is "builtins.tuple[Any, ...]"
2602+
aa: tuple[Unpack[tuple[Any, ...]]]
2603+
reveal_type(join(xs, aa)) # N: Revealed type is "builtins.tuple[Any, ...]"
2604+
reveal_type(join(aa, xs)) # N: Revealed type is "builtins.tuple[Any, ...]"
2605+
ai: tuple[int, Unpack[tuple[Any, ...]]]
2606+
reveal_type(join(xsi, ai)) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]"
2607+
reveal_type(join(ai, xsi)) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]"
2608+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)