diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 0e6a8bf8a829..6c47670d6687 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -203,6 +203,14 @@ def visit_tuple_type(self, t: TupleType) -> Type: return unpacked return result + def visit_callable_type(self, t: CallableType) -> Type: + result = super().visit_callable_type(t) + assert isinstance(result, ProperType) and isinstance(result, CallableType) + # Usually this is done in semanal_typeargs.py, but erasure can create + # a non-normal callable from normal one. + result.normalize_trivial_unpack() + return result + def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: if self.erase_id(t.id): return t.tuple_fallback.copy_modified(args=[self.replacement]) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 8750da34d963..031f86e7dfff 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -226,6 +226,8 @@ def visit_instance(self, t: Instance) -> Type: if isinstance(arg, UnpackType): unpacked = get_proper_type(arg.type) if isinstance(unpacked, Instance): + # TODO: this and similar asserts below may be unsafe because get_proper_type() + # may be called during semantic analysis before all invalid types are removed. assert unpacked.type.fullname == "builtins.tuple" args = list(unpacked.args) return t.copy_modified(args=args) @@ -333,10 +335,7 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l var_arg_type = get_proper_type(var_arg.type) new_unpack: Type - if isinstance(var_arg_type, Instance): - # we have something like Unpack[Tuple[Any, ...]] - new_unpack = UnpackType(var_arg.type.accept(self)) - elif isinstance(var_arg_type, TupleType): + if isinstance(var_arg_type, TupleType): # We have something like Unpack[Tuple[Unpack[Ts], X1, X2]] expanded_tuple = var_arg_type.accept(self) assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType) @@ -348,6 +347,11 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l fallback = var_arg_type.tuple_fallback expanded_items = self.expand_unpack(var_arg) new_unpack = UnpackType(TupleType(expanded_items, fallback)) + # Since get_proper_type() may be called in semanal.py before callable + # normalization happens, we need to also handle non-normal cases here. + elif isinstance(var_arg_type, Instance): + # we have something like Unpack[Tuple[Any, ...]] + new_unpack = UnpackType(var_arg.type.accept(self)) else: # We have invalid type in Unpack. This can happen when expanding aliases # to Callable[[*Invalid], Ret] diff --git a/mypy/join.py b/mypy/join.py index a5c30b4b835d..ec60bf8f1520 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -299,6 +299,9 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType: def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType: if self.s == t: return t + if isinstance(self.s, Instance) and is_subtype(t.upper_bound, self.s): + # TODO: should we do this more generally and for all TypeVarLikeTypes? + return self.s return self.default(self.s) def visit_unpack_type(self, t: UnpackType) -> UnpackType: @@ -350,6 +353,8 @@ def visit_instance(self, t: Instance) -> ProperType: return join_types(t, self.s) elif isinstance(self.s, LiteralType): return join_types(t, self.s) + elif isinstance(self.s, TypeVarTupleType) and is_subtype(self.s.upper_bound, t): + return t else: return self.default(self.s) @@ -562,6 +567,10 @@ def visit_tuple_type(self, t: TupleType) -> ProperType: assert isinstance(fallback, Instance) items = self.join_tuples(self.s, t) if items is not None: + if len(items) == 1 and isinstance(item := items[0], UnpackType): + if isinstance(unpacked := get_proper_type(item.type), Instance): + # Avoid double-wrapping tuple[*tuple[X, ...]] + return unpacked return TupleType(items, fallback) else: # TODO: should this be a default fallback behaviour like for meet? diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 75cc7e25fde3..938be21201e9 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from contextlib import contextmanager from typing import Any, Callable, Final, TypeVar, cast from typing_extensions import TypeAlias as _TypeAlias @@ -414,6 +414,9 @@ def _is_subtype(self, left: Type, right: Type) -> bool: return is_proper_subtype(left, right, subtype_context=self.subtype_context) return is_subtype(left, right, subtype_context=self.subtype_context) + def _all_subtypes(self, lefts: Iterable[Type], rights: Iterable[Type]) -> bool: + return all(self._is_subtype(li, ri) for (li, ri) in zip(lefts, rights)) + # visit_x(left) means: is left (which is an instance of X) a subtype of right? def visit_unbound_type(self, left: UnboundType) -> bool: @@ -856,11 +859,25 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool: # There are some items on the left that will never have a matching length # on the right. return False + left_prefix = left_unpack_index + left_suffix = len(left.items) - left_prefix - 1 left_unpack = left.items[left_unpack_index] assert isinstance(left_unpack, UnpackType) left_unpacked = get_proper_type(left_unpack.type) if not isinstance(left_unpacked, Instance): - # *Ts unpacks can't be split. + # *Ts unpack can't be split, except if it is all mapped to Anys or objects. + if self.is_top_type(right_item): + right_prefix_types, middle, right_suffix_types = split_with_prefix_and_suffix( + tuple(right.items), left_prefix, left_suffix + ) + if not all( + self.is_top_type(ri) or isinstance(ri, UnpackType) for ri in middle + ): + return False + # Also check the tails match as well. + return self._all_subtypes( + left.items[:left_prefix], right_prefix_types + ) and self._all_subtypes(left.items[-left_suffix:], right_suffix_types) return False assert left_unpacked.type.fullname == "builtins.tuple" left_item = left_unpacked.args[0] @@ -871,8 +888,6 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool: # and then check subtyping for all finite overlaps. if not self._is_subtype(left_item, right_item): return False - left_prefix = left_unpack_index - left_suffix = len(left.items) - left_prefix - 1 max_overlap = max(0, right_prefix - left_prefix, right_suffix - left_suffix) for overlap in range(max_overlap + 1): repr_items = left.items[:left_prefix] + [left_item] * overlap @@ -883,6 +898,11 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool: return False return True + def is_top_type(self, typ: Type) -> bool: + if not self.proper_subtype and isinstance(get_proper_type(typ), AnyType): + return True + return is_named_instance(typ, "builtins.object") + def visit_typeddict_type(self, left: TypedDictType) -> bool: right = self.right if isinstance(right, Instance): @@ -1653,17 +1673,18 @@ def are_parameters_compatible( return True trivial_suffix = is_trivial_suffix(right) and not is_proper_subtype + trivial_vararg_suffix = False if ( - right.arg_kinds == [ARG_STAR] - and isinstance(get_proper_type(right.arg_types[0]), AnyType) + right.arg_kinds[-1:] == [ARG_STAR] + and isinstance(get_proper_type(right.arg_types[-1]), AnyType) and not is_proper_subtype + and all(k.is_positional(star=True) for k in left.arg_kinds) ): # Similar to how (*Any, **Any) is considered a supertype of all callables, we consider # (*Any) a supertype of all callables with positional arguments. This is needed in # particular because we often refuse to try type inference if actual type is not # a subtype of erased template type. - if all(k.is_positional() for k in left.arg_kinds) and ignore_pos_arg_names: - return True + trivial_vararg_suffix = True # Match up corresponding arguments and check them for compatibility. In # every pair (argL, argR) of corresponding arguments from L and R, argL must @@ -1697,7 +1718,11 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N return not allow_partial_overlap and not trivial_suffix return not is_compat(right_arg.typ, left_arg.typ) - if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2): + if ( + _incompatible(left_star, right_star) + and not trivial_vararg_suffix + or _incompatible(left_star2, right_star2) + ): return False # Phase 1b: Check non-star args: for every arg right can accept, left must @@ -1727,8 +1752,8 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N # Phase 1c: Check var args. Right has an infinite series of optional positional # arguments. Get all further positional args of left, and make sure # they're more general than the corresponding member in right. - # TODO: are we handling UnpackType correctly here? - if right_star is not None: + # TODO: handle suffix in UnpackType (i.e. *args: *Tuple[Ts, X, Y]). + if right_star is not None and not trivial_vararg_suffix: # Synthesize an anonymous formal argument for the right right_by_position = right.try_synthesizing_arg_from_vararg(None) assert right_by_position is not None diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 174441237ab4..a42519c64956 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -1021,7 +1021,7 @@ def test_variadic_tuple_joins(self) -> None: self.assert_join( self.tuple(self.fx.a, self.fx.a), self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))), - self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))), + Instance(self.fx.std_tuplei, [self.fx.a]), ) self.assert_join( self.tuple(self.fx.a, self.fx.a), @@ -1049,12 +1049,12 @@ def test_variadic_tuple_joins(self) -> None: self.tuple( self.fx.a, UnpackType(Instance(self.fx.std_tuplei, [self.fx.a])), self.fx.a ), - self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))), + Instance(self.fx.std_tuplei, [self.fx.a]), ) self.assert_join( self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))), self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))), - self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))), + Instance(self.fx.std_tuplei, [self.fx.a]), ) self.assert_join( self.tuple(UnpackType(Instance(self.fx.std_tuplei, [self.fx.a])), self.fx.a), @@ -1584,11 +1584,12 @@ def make_call(*items: tuple[str, str | None]) -> CallExpr: class TestExpandTypeLimitGetProperType(TestCase): # WARNING: do not increase this number unless absolutely necessary, # and you understand what you are doing. - ALLOWED_GET_PROPER_TYPES = 9 + ALLOWED_GET_PROPER_TYPES = 7 @skipUnless(mypy.expandtype.__file__.endswith(".py"), "Skip for compiled mypy") def test_count_get_proper_type(self) -> None: with open(mypy.expandtype.__file__) as f: code = f.read() - get_proper_type_count = len(re.findall("get_proper_type", code)) + get_proper_type_count = len(re.findall(r"get_proper_type\(", code)) + get_proper_type_count -= len(re.findall(r"get_proper_type\(\)", code)) assert get_proper_type_count == self.ALLOWED_GET_PROPER_TYPES diff --git a/test-data/unit/check-typevar-tuple.test b/test-data/unit/check-typevar-tuple.test index 754151ffb559..20f3ef966bdc 100644 --- a/test-data/unit/check-typevar-tuple.test +++ b/test-data/unit/check-typevar-tuple.test @@ -2305,18 +2305,21 @@ def higher_order(f: _CallableValue) -> None: ... def good1(*args: int) -> None: ... def good2(*args: str) -> int: ... -def bad1(a: str, b: int, /) -> None: ... -def bad2(c: bytes, *args: int) -> str: ... -def bad3(*, d: str) -> int: ... -def bad4(**kwargs: None) -> None: ... +# These are special-cased for *args: Any (as opposite to *args: object) +def ok1(a: str, b: int, /) -> None: ... +def ok2(c: bytes, *args: int) -> str: ... + +def bad1(*, d: str) -> int: ... +def bad2(**kwargs: None) -> None: ... higher_order(good1) higher_order(good2) -higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[str, int], None]"; expected "Callable[[VarArg(Any)], Any]" -higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[bytes, VarArg(int)], str]"; expected "Callable[[VarArg(Any)], Any]" -higher_order(bad3) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]" -higher_order(bad4) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]" +higher_order(ok1) +higher_order(ok2) + +higher_order(bad1) # E: Argument 1 to "higher_order" has incompatible type "Callable[[NamedArg(str, 'd')], int]"; expected "Callable[[VarArg(Any)], Any]" +higher_order(bad2) # E: Argument 1 to "higher_order" has incompatible type "Callable[[KwArg(None)], None]"; expected "Callable[[VarArg(Any)], Any]" [builtins fixtures/tuple.pyi] [case testAliasToCallableWithUnpack2] @@ -2513,3 +2516,89 @@ x4: Foo[Unpack[tuple[str, ...]]] y4: Foo[Unpack[tuple[int, int]]] x4 is y4 # E: Non-overlapping identity check (left operand type: "Foo[Unpack[Tuple[str, ...]]]", right operand type: "Foo[int, int]") [builtins fixtures/tuple.pyi] + +[case testTypeVarTupleErasureNormalized] +from typing import TypeVarTuple, Unpack, Generic, Union +from collections.abc import Callable + +Args = TypeVarTuple("Args") + +class Built(Generic[Unpack[Args]]): + pass + +def example( + fn: Union[Built[Unpack[Args]], Callable[[Unpack[Args]], None]] +) -> Built[Unpack[Args]]: ... + +@example +def command() -> None: + return +reveal_type(command) # N: Revealed type is "__main__.Built[()]" +[builtins fixtures/tuple.pyi] + +[case testTypeVarTupleSelfMappedPrefix] +from typing import TypeVarTuple, Generic, Unpack + +Ts = TypeVarTuple("Ts") +class Base(Generic[Unpack[Ts]]): + attr: tuple[Unpack[Ts]] + + @property + def prop(self) -> tuple[Unpack[Ts]]: + return self.attr + + def meth(self) -> tuple[Unpack[Ts]]: + return self.attr + +Ss = TypeVarTuple("Ss") +class Derived(Base[str, Unpack[Ss]]): + def test(self) -> None: + reveal_type(self.attr) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]" + reveal_type(self.prop) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]" + reveal_type(self.meth()) # N: Revealed type is "Tuple[builtins.str, Unpack[Ss`1]]" +[builtins fixtures/property.pyi] + +[case testTypeVarTupleProtocolPrefix] +from typing import Protocol, Unpack, TypeVarTuple + +Ts = TypeVarTuple("Ts") +class A(Protocol[Unpack[Ts]]): + def f(self, z: str, *args: Unpack[Ts]) -> None: ... + +class C: + def f(self, z: str, x: int) -> None: ... + +def f(x: A[Unpack[Ts]]) -> tuple[Unpack[Ts]]: ... + +reveal_type(f(C())) # N: Revealed type is "Tuple[builtins.int]" +[builtins fixtures/tuple.pyi] + +[case testTypeVarTupleHomogeneousCallableNormalized] +from typing import Generic, Unpack, TypeVarTuple + +Ts = TypeVarTuple("Ts") +class C(Generic[Unpack[Ts]]): + def foo(self, *args: Unpack[Ts]) -> None: ... + +c: C[Unpack[tuple[int, ...]]] +reveal_type(c.foo) # N: Revealed type is "def (*args: builtins.int)" +[builtins fixtures/tuple.pyi] + +[case testTypeVarTupleJoinInstanceTypeVar] +from typing import Any, Unpack, TypeVarTuple, TypeVar + +T = TypeVar("T") +Ts = TypeVarTuple("Ts") + +def join(x: T, y: T) -> T: ... +def test(xs: tuple[Unpack[Ts]], xsi: tuple[int, Unpack[Ts]]) -> None: + a: tuple[Any, ...] + reveal_type(join(xs, a)) # N: Revealed type is "builtins.tuple[Any, ...]" + reveal_type(join(a, xs)) # N: Revealed type is "builtins.tuple[Any, ...]" + aa: tuple[Unpack[tuple[Any, ...]]] + reveal_type(join(xs, aa)) # N: Revealed type is "builtins.tuple[Any, ...]" + reveal_type(join(aa, xs)) # N: Revealed type is "builtins.tuple[Any, ...]" + ai: tuple[int, Unpack[tuple[Any, ...]]] + reveal_type(join(xsi, ai)) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]" + reveal_type(join(ai, xsi)) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[Any, ...]]]" +[builtins fixtures/tuple.pyi]