Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of Any/object in variadic generics #18643

Merged
merged 5 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ def visit_tuple_type(self, t: TupleType) -> Type:
return unpacked
return result

def visit_callable_type(self, t: CallableType) -> Type:
result = super().visit_callable_type(t)
assert isinstance(result, ProperType) and isinstance(result, CallableType)
# Usually this is done in semanal_typeargs.py, but erasure can create
# a non-normal callable from normal one.
result.normalize_trivial_unpack()
return result

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
if self.erase_id(t.id):
return t.tuple_fallback.copy_modified(args=[self.replacement])
Expand Down
12 changes: 8 additions & 4 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def visit_instance(self, t: Instance) -> Type:
if isinstance(arg, UnpackType):
unpacked = get_proper_type(arg.type)
if isinstance(unpacked, Instance):
# TODO: this and similar asserts below may be unsafe because get_proper_type()
# may be called during semantic analysis before all invalid types are removed.
assert unpacked.type.fullname == "builtins.tuple"
args = list(unpacked.args)
return t.copy_modified(args=args)
Expand Down Expand Up @@ -333,10 +335,7 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l

var_arg_type = get_proper_type(var_arg.type)
new_unpack: Type
if isinstance(var_arg_type, Instance):
# we have something like Unpack[Tuple[Any, ...]]
new_unpack = UnpackType(var_arg.type.accept(self))
elif isinstance(var_arg_type, TupleType):
if isinstance(var_arg_type, TupleType):
# We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
expanded_tuple = var_arg_type.accept(self)
assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
Expand All @@ -348,6 +347,11 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l
fallback = var_arg_type.tuple_fallback
expanded_items = self.expand_unpack(var_arg)
new_unpack = UnpackType(TupleType(expanded_items, fallback))
# Since get_proper_type() may be called in semanal.py before callable
# normalization happens, we need to also handle non-normal cases here.
elif isinstance(var_arg_type, Instance):
# we have something like Unpack[Tuple[Any, ...]]
new_unpack = UnpackType(var_arg.type.accept(self))
else:
# We have invalid type in Unpack. This can happen when expanding aliases
# to Callable[[*Invalid], Ret]
Expand Down
9 changes: 9 additions & 0 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType:
def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
if self.s == t:
return t
if isinstance(self.s, Instance) and is_subtype(t.upper_bound, self.s):
# TODO: should we do this more generally and for all TypeVarLikeTypes?
return self.s
return self.default(self.s)

def visit_unpack_type(self, t: UnpackType) -> UnpackType:
Expand Down Expand Up @@ -350,6 +353,8 @@ def visit_instance(self, t: Instance) -> ProperType:
return join_types(t, self.s)
elif isinstance(self.s, LiteralType):
return join_types(t, self.s)
elif isinstance(self.s, TypeVarTupleType) and is_subtype(self.s.upper_bound, t):
return t
else:
return self.default(self.s)

Expand Down Expand Up @@ -562,6 +567,10 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
assert isinstance(fallback, Instance)
items = self.join_tuples(self.s, t)
if items is not None:
if len(items) == 1 and isinstance(item := items[0], UnpackType):
if isinstance(unpacked := get_proper_type(item.type), Instance):
# Avoid double-wrapping tuple[*tuple[X, ...]]
return unpacked
return TupleType(items, fallback)
else:
# TODO: should this be a default fallback behaviour like for meet?
Expand Down
47 changes: 36 additions & 11 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Iterator
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import Any, Callable, Final, TypeVar, cast
from typing_extensions import TypeAlias as _TypeAlias
Expand Down Expand Up @@ -414,6 +414,9 @@ def _is_subtype(self, left: Type, right: Type) -> bool:
return is_proper_subtype(left, right, subtype_context=self.subtype_context)
return is_subtype(left, right, subtype_context=self.subtype_context)

def _all_subtypes(self, lefts: Iterable[Type], rights: Iterable[Type]) -> bool:
return all(self._is_subtype(li, ri) for (li, ri) in zip(lefts, rights))

# visit_x(left) means: is left (which is an instance of X) a subtype of right?

def visit_unbound_type(self, left: UnboundType) -> bool:
Expand Down Expand Up @@ -856,11 +859,25 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
# There are some items on the left that will never have a matching length
# on the right.
return False
left_prefix = left_unpack_index
left_suffix = len(left.items) - left_prefix - 1
left_unpack = left.items[left_unpack_index]
assert isinstance(left_unpack, UnpackType)
left_unpacked = get_proper_type(left_unpack.type)
if not isinstance(left_unpacked, Instance):
# *Ts unpacks can't be split.
# *Ts unpack can't be split, except if it is all mapped to Anys or objects.
if self.is_top_type(right_item):
right_prefix_types, middle, right_suffix_types = split_with_prefix_and_suffix(
tuple(right.items), left_prefix, left_suffix
)
if not all(
self.is_top_type(ri) or isinstance(ri, UnpackType) for ri in middle
):
return False
# Also check the tails match as well.
return self._all_subtypes(
left.items[:left_prefix], right_prefix_types
) and self._all_subtypes(left.items[-left_suffix:], right_suffix_types)
return False
assert left_unpacked.type.fullname == "builtins.tuple"
left_item = left_unpacked.args[0]
Expand All @@ -871,8 +888,6 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
# and then check subtyping for all finite overlaps.
if not self._is_subtype(left_item, right_item):
return False
left_prefix = left_unpack_index
left_suffix = len(left.items) - left_prefix - 1
max_overlap = max(0, right_prefix - left_prefix, right_suffix - left_suffix)
for overlap in range(max_overlap + 1):
repr_items = left.items[:left_prefix] + [left_item] * overlap
Expand All @@ -883,6 +898,11 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
return False
return True

def is_top_type(self, typ: Type) -> bool:
if not self.proper_subtype and isinstance(get_proper_type(typ), AnyType):
return True
return is_named_instance(typ, "builtins.object")

def visit_typeddict_type(self, left: TypedDictType) -> bool:
right = self.right
if isinstance(right, Instance):
Expand Down Expand Up @@ -1653,17 +1673,18 @@ def are_parameters_compatible(
return True
trivial_suffix = is_trivial_suffix(right) and not is_proper_subtype

trivial_vararg_suffix = False
if (
right.arg_kinds == [ARG_STAR]
and isinstance(get_proper_type(right.arg_types[0]), AnyType)
right.arg_kinds[-1:] == [ARG_STAR]
and isinstance(get_proper_type(right.arg_types[-1]), AnyType)
and not is_proper_subtype
and all(k.is_positional(star=True) for k in left.arg_kinds)
):
# Similar to how (*Any, **Any) is considered a supertype of all callables, we consider
# (*Any) a supertype of all callables with positional arguments. This is needed in
# particular because we often refuse to try type inference if actual type is not
# a subtype of erased template type.
if all(k.is_positional() for k in left.arg_kinds) and ignore_pos_arg_names:
return True
trivial_vararg_suffix = True

# Match up corresponding arguments and check them for compatibility. In
# every pair (argL, argR) of corresponding arguments from L and R, argL must
Expand Down Expand Up @@ -1697,7 +1718,11 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
return not allow_partial_overlap and not trivial_suffix
return not is_compat(right_arg.typ, left_arg.typ)

if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
if (
_incompatible(left_star, right_star)
and not trivial_vararg_suffix
or _incompatible(left_star2, right_star2)
):
return False

# Phase 1b: Check non-star args: for every arg right can accept, left must
Expand Down Expand Up @@ -1727,8 +1752,8 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# Phase 1c: Check var args. Right has an infinite series of optional positional
# arguments. Get all further positional args of left, and make sure
# they're more general than the corresponding member in right.
# TODO: are we handling UnpackType correctly here?
if right_star is not None:
# TODO: handle suffix in UnpackType (i.e. *args: *Tuple[Ts, X, Y]).
if right_star is not None and not trivial_vararg_suffix:
# Synthesize an anonymous formal argument for the right
right_by_position = right.try_synthesizing_arg_from_vararg(None)
assert right_by_position is not None
Expand Down
11 changes: 6 additions & 5 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def test_variadic_tuple_joins(self) -> None:
self.assert_join(
self.tuple(self.fx.a, self.fx.a),
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
Instance(self.fx.std_tuplei, [self.fx.a]),
)
self.assert_join(
self.tuple(self.fx.a, self.fx.a),
Expand Down Expand Up @@ -1049,12 +1049,12 @@ def test_variadic_tuple_joins(self) -> None:
self.tuple(
self.fx.a, UnpackType(Instance(self.fx.std_tuplei, [self.fx.a])), self.fx.a
),
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
Instance(self.fx.std_tuplei, [self.fx.a]),
)
self.assert_join(
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))),
Instance(self.fx.std_tuplei, [self.fx.a]),
)
self.assert_join(
self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a])), self.fx.a),
Expand Down Expand Up @@ -1584,11 +1584,12 @@ def make_call(*items: tuple[str, str | None]) -> CallExpr:
class TestExpandTypeLimitGetProperType(TestCase):
# WARNING: do not increase this number unless absolutely necessary,
# and you understand what you are doing.
ALLOWED_GET_PROPER_TYPES = 9
ALLOWED_GET_PROPER_TYPES = 7

@skipUnless(mypy.expandtype.__file__.endswith(".py"), "Skip for compiled mypy")
def test_count_get_proper_type(self) -> None:
with open(mypy.expandtype.__file__) as f:
code = f.read()
get_proper_type_count = len(re.findall("get_proper_type", code))
get_proper_type_count = len(re.findall(r"get_proper_type\(", code))
get_proper_type_count -= len(re.findall(r"get_proper_type\(\)", code))
assert get_proper_type_count == self.ALLOWED_GET_PROPER_TYPES
105 changes: 97 additions & 8 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -2305,18 +2305,21 @@ def higher_order(f: _CallableValue) -> None: ...
def good1(*args: int) -> None: ...
def good2(*args: str) -> int: ...

def bad1(a: str, b: int, /) -> None: ...
def bad2(c: bytes, *args: int) -> str: ...
def bad3(*, d: str) -> int: ...
def bad4(**kwargs: None) -> None: ...
# These are special-cased for *args: Any (as opposite to *args: object)
def ok1(a: str, b: int, /) -> None: ...
def ok2(c: bytes, *args: int) -> str: ...

def bad1(*, d: str) -> int: ...
def bad2(**kwargs: None) -> None: ...

higher_order(good1)
higher_order(good2)

higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[str, int], None]"; expected "Callable[[VarArg(Any)], Any]"
higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[bytes, VarArg(int)], str]"; expected "Callable[[VarArg(Any)], Any]"
higher_order(bad3) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]"
higher_order(bad4) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]"
higher_order(ok1)
higher_order(ok2)

higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]"
higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]"
[builtins fixtures/tuple.pyi]

[case testAliasToCallableWithUnpack2]
Expand Down Expand Up @@ -2513,3 +2516,89 @@ x4: Foo[Unpack[tuple[str, ...]]]
y4: Foo[Unpack[tuple[int, int]]]
x4 is y4 # E: Non-overlapping identity check (left operand type: "Foo[Unpack[Tuple[str, ...]]]", right operand type: "Foo[int, int]")
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleErasureNormalized]
from typing import TypeVarTuple, Unpack, Generic, Union
from collections.abc import Callable

Args = TypeVarTuple("Args")

class Built(Generic[Unpack[Args]]):
pass

def example(
fn: Union[Built[Unpack[Args]], Callable[[Unpack[Args]], None]]
) -> Built[Unpack[Args]]: ...

@example
def command() -> None:
return
reveal_type(command) # N: Revealed type is "__main__.Built[()]"
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleSelfMappedPrefix]
from typing import TypeVarTuple, Generic, Unpack

Ts = TypeVarTuple("Ts")
class Base(Generic[Unpack[Ts]]):
attr: tuple[Unpack[Ts]]

@property
def prop(self) -> tuple[Unpack[Ts]]:
return self.attr

def meth(self) -> tuple[Unpack[Ts]]:
return self.attr

Ss = TypeVarTuple("Ss")
class Derived(Base[str, Unpack[Ss]]):
def test(self) -> None:
reveal_type(self.attr) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
reveal_type(self.prop) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
reveal_type(self.meth()) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]"
[builtins fixtures/property.pyi]

[case testTypeVarTupleProtocolPrefix]
from typing import Protocol, Unpack, TypeVarTuple

Ts = TypeVarTuple("Ts")
class A(Protocol[Unpack[Ts]]):
def f(self, z: str, *args: Unpack[Ts]) -> None: ...

class C:
def f(self, z: str, x: int) -> None: ...

def f(x: A[Unpack[Ts]]) -> tuple[Unpack[Ts]]: ...

reveal_type(f(C())) # N: Revealed type is "Tuple[builtins.int]"
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleHomogeneousCallableNormalized]
from typing import Generic, Unpack, TypeVarTuple

Ts = TypeVarTuple("Ts")
class C(Generic[Unpack[Ts]]):
def foo(self, *args: Unpack[Ts]) -> None: ...

c: C[Unpack[tuple[int, ...]]]
reveal_type(c.foo) # N: Revealed type is "def (*args: builtins.int)"
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleJoinInstanceTypeVar]
from typing import Any, Unpack, TypeVarTuple, TypeVar

T = TypeVar("T")
Ts = TypeVarTuple("Ts")

def join(x: T, y: T) -> T: ...
def test(xs: tuple[Unpack[Ts]], xsi: tuple[int, Unpack[Ts]]) -> None:
a: tuple[Any, ...]
reveal_type(join(xs, a)) # N: Revealed type is "builtins.tuple[Any, ...]"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also try the opposite argument order (join(a, xs)) here and below?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, added symmetric check for each pair.

reveal_type(join(a, xs)) # N: Revealed type is "builtins.tuple[Any, ...]"
aa: tuple[Unpack[tuple[Any, ...]]]
reveal_type(join(xs, aa)) # N: Revealed type is "builtins.tuple[Any, ...]"
reveal_type(join(aa, xs)) # N: Revealed type is "builtins.tuple[Any, ...]"
ai: tuple[int, Unpack[tuple[Any, ...]]]
reveal_type(join(xsi, ai)) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]"
reveal_type(join(ai, xsi)) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]"
[builtins fixtures/tuple.pyi]