Skip to content

Commit a706914

Browse files
Support ParamSpec mapping with functools.partial (#17355)
Follow-up for #17323, resolving a false positive discovered there. Closes #17960. This enables use of `functools.partial` to bind some `*args` or `**kwargs` on a callable typed with `ParamSpec`. --------- Co-authored-by: Shantanu Jain <[email protected]>
1 parent e7db89c commit a706914

File tree

4 files changed

+174
-7
lines changed

4 files changed

+174
-7
lines changed

mypy/checkexpr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2377,7 +2377,11 @@ def check_argument_count(
23772377
# Positional argument when expecting a keyword argument.
23782378
self.msg.too_many_positional_arguments(callee, context)
23792379
ok = False
2380-
elif callee.param_spec() is not None and not formal_to_actual[i]:
2380+
elif (
2381+
callee.param_spec() is not None
2382+
and not formal_to_actual[i]
2383+
and callee.special_sig != "partial"
2384+
):
23812385
self.msg.too_few_arguments(callee, context, actual_names)
23822386
ok = False
23832387
return ok

mypy/plugins/functools.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88
import mypy.plugin
99
import mypy.semanal
1010
from mypy.argmap import map_actuals_to_formals
11-
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
11+
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, NameExpr, Var
1212
from mypy.plugins.common import add_method_to_class
1313
from mypy.typeops import get_all_type_vars
1414
from mypy.types import (
1515
AnyType,
1616
CallableType,
1717
Instance,
1818
Overloaded,
19+
ParamSpecFlavor,
20+
ParamSpecType,
1921
Type,
2022
TypeOfAny,
2123
TypeVarType,
@@ -202,6 +204,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
202204
continue
203205
can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})
204206

207+
# special_sig="partial" allows omission of args/kwargs typed with ParamSpec
205208
defaulted = fn_type.copy_modified(
206209
arg_kinds=[
207210
(
@@ -218,6 +221,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
218221
# Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
219222
if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
220223
],
224+
special_sig="partial",
221225
)
222226
if defaulted.line < 0:
223227
# Make up a line number if we don't have one
@@ -296,10 +300,19 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
296300
arg_kinds=partial_kinds,
297301
arg_names=partial_names,
298302
ret_type=ret_type,
303+
special_sig="partial",
299304
)
300305

301306
ret = ctx.api.named_generic_type(PARTIAL, [ret_type])
302307
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
308+
if partially_applied.param_spec():
309+
assert ret.extra_attrs is not None # copy_with_extra_attr above ensures this
310+
attrs = ret.extra_attrs.copy()
311+
if ArgKind.ARG_STAR in actual_arg_kinds:
312+
attrs.immutable.add("__mypy_partial_paramspec_args_bound")
313+
if ArgKind.ARG_STAR2 in actual_arg_kinds:
314+
attrs.immutable.add("__mypy_partial_paramspec_kwargs_bound")
315+
ret.extra_attrs = attrs
303316
return ret
304317

305318

@@ -314,7 +327,8 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
314327
):
315328
return ctx.default_return_type
316329

317-
partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
330+
extra_attrs = ctx.type.extra_attrs
331+
partial_type = get_proper_type(extra_attrs.attrs["__mypy_partial"])
318332
if len(ctx.arg_types) != 2: # *args, **kwargs
319333
return ctx.default_return_type
320334

@@ -332,11 +346,36 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
332346
actual_arg_kinds.append(ctx.arg_kinds[i][j])
333347
actual_arg_names.append(ctx.arg_names[i][j])
334348

335-
result = ctx.api.expr_checker.check_call(
349+
result, _ = ctx.api.expr_checker.check_call(
336350
callee=partial_type,
337351
args=actual_args,
338352
arg_kinds=actual_arg_kinds,
339353
arg_names=actual_arg_names,
340354
context=ctx.context,
341355
)
342-
return result[0]
356+
if not isinstance(partial_type, CallableType) or partial_type.param_spec() is None:
357+
return result
358+
359+
args_bound = "__mypy_partial_paramspec_args_bound" in extra_attrs.immutable
360+
kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extra_attrs.immutable
361+
362+
passed_paramspec_parts = [
363+
arg.node.type
364+
for arg in actual_args
365+
if isinstance(arg, NameExpr)
366+
and isinstance(arg.node, Var)
367+
and isinstance(arg.node.type, ParamSpecType)
368+
]
369+
# ensure *args: P.args
370+
args_passed = any(part.flavor == ParamSpecFlavor.ARGS for part in passed_paramspec_parts)
371+
if not args_bound and not args_passed:
372+
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
373+
elif args_bound and args_passed:
374+
ctx.api.expr_checker.msg.too_many_arguments(partial_type, ctx.context)
375+
376+
# ensure **kwargs: P.kwargs
377+
kwargs_passed = any(part.flavor == ParamSpecFlavor.KWARGS for part in passed_paramspec_parts)
378+
if not kwargs_bound and not kwargs_passed:
379+
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
380+
381+
return result

mypy/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,8 +1827,8 @@ class CallableType(FunctionLike):
18271827
"implicit", # Was this type implicitly generated instead of explicitly
18281828
# specified by the user?
18291829
"special_sig", # Non-None for signatures that require special handling
1830-
# (currently only value is 'dict' for a signature similar to
1831-
# 'dict')
1830+
# (currently only values are 'dict' for a signature similar to
1831+
# 'dict' and 'partial' for a `functools.partial` evaluation)
18321832
"from_type_type", # Was this callable generated by analyzing Type[...]
18331833
# instantiation?
18341834
"bound_args", # Bound type args, mostly unused but may be useful for

test-data/unit/check-parameter-specification.test

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,3 +2338,127 @@ reveal_type(handle_reversed(Child())) # N: Revealed type is "builtins.str"
23382338
reveal_type(handle_reversed(NotChild())) # N: Revealed type is "builtins.str"
23392339

23402340
[builtins fixtures/paramspec.pyi]
2341+
2342+
[case testBindPartial]
2343+
from functools import partial
2344+
from typing_extensions import ParamSpec
2345+
from typing import Callable, TypeVar
2346+
2347+
P = ParamSpec("P")
2348+
T = TypeVar("T")
2349+
2350+
def run(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2351+
func2 = partial(func, **kwargs)
2352+
return func2(*args)
2353+
2354+
def run2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2355+
func2 = partial(func, *args)
2356+
return func2(**kwargs)
2357+
2358+
def run3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2359+
func2 = partial(func, *args, **kwargs)
2360+
return func2()
2361+
2362+
def run4(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2363+
func2 = partial(func, *args, **kwargs)
2364+
return func2(**kwargs)
2365+
2366+
def run_bad(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2367+
func2 = partial(func, *args, **kwargs)
2368+
return func2(*args) # E: Too many arguments
2369+
2370+
def run_bad2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2371+
func2 = partial(func, **kwargs)
2372+
return func2(**kwargs) # E: Too few arguments
2373+
2374+
def run_bad3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
2375+
func2 = partial(func, *args)
2376+
return func2() # E: Too few arguments
2377+
2378+
[builtins fixtures/paramspec.pyi]
2379+
2380+
[case testBindPartialConcatenate]
2381+
from functools import partial
2382+
from typing_extensions import Concatenate, ParamSpec
2383+
from typing import Callable, TypeVar
2384+
2385+
P = ParamSpec("P")
2386+
T = TypeVar("T")
2387+
2388+
def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2389+
func2 = partial(func, 1, **kwargs)
2390+
return func2(*args)
2391+
2392+
def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2393+
func2 = partial(func, **kwargs)
2394+
p = [""]
2395+
func2(1, *p) # E: Too few arguments \
2396+
# E: Argument 2 has incompatible type "*List[str]"; expected "P.args"
2397+
func2(1, 2, *p) # E: Too few arguments \
2398+
# E: Argument 2 has incompatible type "int"; expected "P.args" \
2399+
# E: Argument 3 has incompatible type "*List[str]"; expected "P.args"
2400+
func2(1, *args, *p) # E: Argument 3 has incompatible type "*List[str]"; expected "P.args"
2401+
func2(1, *p, *args) # E: Argument 2 has incompatible type "*List[str]"; expected "P.args"
2402+
return func2(1, *args)
2403+
2404+
def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2405+
func2 = partial(func, 1, *args)
2406+
d = {"":""}
2407+
func2(**d) # E: Too few arguments \
2408+
# E: Argument 1 has incompatible type "**Dict[str, str]"; expected "P.kwargs"
2409+
return func2(**kwargs)
2410+
2411+
def run4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2412+
func2 = partial(func, 1)
2413+
return func2(*args, **kwargs)
2414+
2415+
def run5(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2416+
func2 = partial(func, 1, *args, **kwargs)
2417+
func2()
2418+
return func2(**kwargs)
2419+
2420+
def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2421+
func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int"
2422+
return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args"
2423+
2424+
def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2425+
func2 = partial(func, 1, *args)
2426+
func2() # E: Too few arguments
2427+
func2(*args, **kwargs) # E: Too many arguments
2428+
return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args"
2429+
2430+
def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2431+
func2 = partial(func, 1, **kwargs)
2432+
func2() # E: Too few arguments
2433+
return func2(1, *args) # E: Argument 1 has incompatible type "int"; expected "P.args"
2434+
2435+
def run_bad4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2436+
func2 = partial(func, 1)
2437+
func2() # E: Too few arguments
2438+
func2(*args) # E: Too few arguments
2439+
func2(1, *args) # E: Too few arguments \
2440+
# E: Argument 1 has incompatible type "int"; expected "P.args"
2441+
func2(1, **kwargs) # E: Too few arguments \
2442+
# E: Argument 1 has incompatible type "int"; expected "P.args"
2443+
return func2(**kwargs) # E: Too few arguments
2444+
2445+
[builtins fixtures/paramspec.pyi]
2446+
2447+
[case testOtherVarArgs]
2448+
from functools import partial
2449+
from typing_extensions import Concatenate, ParamSpec
2450+
from typing import Callable, TypeVar, Tuple
2451+
2452+
P = ParamSpec("P")
2453+
T = TypeVar("T")
2454+
2455+
def run(func: Callable[Concatenate[int, str, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
2456+
func2 = partial(func, **kwargs)
2457+
args_prefix: Tuple[int, str] = (1, 'a')
2458+
func2(*args_prefix) # E: Too few arguments
2459+
func2(*args, *args_prefix) # E: Argument 1 has incompatible type "*P.args"; expected "int" \
2460+
# E: Argument 1 has incompatible type "*P.args"; expected "str" \
2461+
# E: Argument 2 has incompatible type "*Tuple[int, str]"; expected "P.args"
2462+
return func2(*args_prefix, *args)
2463+
2464+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)