Skip to content

Commit 876f636

Browse files
authored
Improve support for functools.partial of overloaded callable protocol (#18639)
Resolves #18637 Mypy's behaviour here is not correct (see test case), but this PR makes mypy's behaviour match what it used to be before we added the functools.partial plugin Support for overloads tracked in #17585
1 parent 5bb681a commit 876f636

File tree

2 files changed

+65
-41
lines changed

2 files changed

+65
-41
lines changed

mypy/checker.py

+48-41
Original file line numberDiff line numberDiff line change
@@ -703,50 +703,57 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
703703
def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> CallableType | None:
704704
"""Get type as seen by an overload item caller."""
705705
inner_type = get_proper_type(inner_type)
706-
outer_type: CallableType | None = None
707-
if inner_type is not None and not isinstance(inner_type, AnyType):
708-
if isinstance(inner_type, TypeVarLikeType):
709-
inner_type = get_proper_type(inner_type.upper_bound)
710-
if isinstance(inner_type, TypeType):
711-
inner_type = get_proper_type(
712-
self.expr_checker.analyze_type_type_callee(inner_type.item, ctx)
713-
)
706+
outer_type: FunctionLike | None = None
707+
if inner_type is None or isinstance(inner_type, AnyType):
708+
return None
709+
if isinstance(inner_type, TypeVarLikeType):
710+
inner_type = get_proper_type(inner_type.upper_bound)
711+
if isinstance(inner_type, TypeType):
712+
inner_type = get_proper_type(
713+
self.expr_checker.analyze_type_type_callee(inner_type.item, ctx)
714+
)
714715

715-
if isinstance(inner_type, CallableType):
716-
outer_type = inner_type
717-
elif isinstance(inner_type, Instance):
718-
inner_call = get_proper_type(
719-
analyze_member_access(
720-
name="__call__",
721-
typ=inner_type,
722-
context=ctx,
723-
is_lvalue=False,
724-
is_super=False,
725-
is_operator=True,
726-
msg=self.msg,
727-
original_type=inner_type,
728-
chk=self,
729-
)
716+
if isinstance(inner_type, FunctionLike):
717+
outer_type = inner_type
718+
elif isinstance(inner_type, Instance):
719+
inner_call = get_proper_type(
720+
analyze_member_access(
721+
name="__call__",
722+
typ=inner_type,
723+
context=ctx,
724+
is_lvalue=False,
725+
is_super=False,
726+
is_operator=True,
727+
msg=self.msg,
728+
original_type=inner_type,
729+
chk=self,
730730
)
731-
if isinstance(inner_call, CallableType):
732-
outer_type = inner_call
733-
elif isinstance(inner_type, UnionType):
734-
union_type = make_simplified_union(inner_type.items)
735-
if isinstance(union_type, UnionType):
736-
items = []
737-
for item in union_type.items:
738-
callable_item = self.extract_callable_type(item, ctx)
739-
if callable_item is None:
740-
break
741-
items.append(callable_item)
742-
else:
743-
joined_type = get_proper_type(join.join_type_list(items))
744-
if isinstance(joined_type, CallableType):
745-
outer_type = joined_type
731+
)
732+
if isinstance(inner_call, FunctionLike):
733+
outer_type = inner_call
734+
elif isinstance(inner_type, UnionType):
735+
union_type = make_simplified_union(inner_type.items)
736+
if isinstance(union_type, UnionType):
737+
items = []
738+
for item in union_type.items:
739+
callable_item = self.extract_callable_type(item, ctx)
740+
if callable_item is None:
741+
break
742+
items.append(callable_item)
746743
else:
747-
return self.extract_callable_type(union_type, ctx)
748-
if outer_type is None:
749-
self.msg.not_callable(inner_type, ctx)
744+
joined_type = get_proper_type(join.join_type_list(items))
745+
if isinstance(joined_type, FunctionLike):
746+
outer_type = joined_type
747+
else:
748+
return self.extract_callable_type(union_type, ctx)
749+
750+
if outer_type is None:
751+
self.msg.not_callable(inner_type, ctx)
752+
return None
753+
if isinstance(outer_type, Overloaded):
754+
return None
755+
756+
assert isinstance(outer_type, CallableType)
750757
return outer_type
751758

752759
def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:

test-data/unit/check-functools.test

+17
Original file line numberDiff line numberDiff line change
@@ -640,3 +640,20 @@ hp = partial(h, 1)
640640
reveal_type(hp(1)) # N: Revealed type is "builtins.int"
641641
hp("a") # E: Argument 1 to "h" has incompatible type "str"; expected "int"
642642
[builtins fixtures/tuple.pyi]
643+
644+
[case testFunctoolsPartialOverloadedCallableProtocol]
645+
from functools import partial
646+
from typing import Callable, Protocol, overload
647+
648+
class P(Protocol):
649+
@overload
650+
def __call__(self, x: int) -> int: ...
651+
@overload
652+
def __call__(self, x: str) -> str: ...
653+
654+
def f(x: P):
655+
reveal_type(partial(x, 1)()) # N: Revealed type is "builtins.int"
656+
657+
# TODO: but this is incorrect, predating the functools.partial plugin
658+
reveal_type(partial(x, "a")()) # N: Revealed type is "builtins.int"
659+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)