Skip to content

Commit 07d8878

Browse files
A5rockshauntsaninjacdce8p
authored
Basic ParamSpec Concatenate and literal support (#11847)
This PR adds a new Parameters proper type to represent ParamSpec parameters (more about this in the PR), along with supporting the Concatenate operator. Closes #11833 Closes #12276 Closes #12257 Refs #8645 External ref python/typeshed#4827 Co-authored-by: Shantanu <[email protected]> Co-authored-by: Marc Mueller <[email protected]>
1 parent 4ff8d04 commit 07d8878

27 files changed

+1473
-121
lines changed

docs/source/config_file.rst

+7
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,13 @@ section of the command line docs.
676676
from foo import bar
677677
__all__ = ['bar']
678678
679+
.. confval:: strict_concatenate
680+
681+
:type: boolean
682+
:default: False
683+
684+
Make arguments prepended via ``Concatenate`` be truly positional-only.
685+
679686
.. confval:: strict_equality
680687

681688
:type: boolean

mypy/applytype.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mypy.expandtype import expand_type
66
from mypy.types import (
77
Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types,
8-
TypeVarLikeType, ProperType, ParamSpecType, get_proper_type
8+
TypeVarLikeType, ProperType, ParamSpecType, Parameters, get_proper_type
99
)
1010
from mypy.nodes import Context
1111

@@ -94,7 +94,7 @@ def apply_generic_arguments(
9494
nt = id_to_type.get(param_spec.id)
9595
if nt is not None:
9696
nt = get_proper_type(nt)
97-
if isinstance(nt, CallableType):
97+
if isinstance(nt, CallableType) or isinstance(nt, Parameters):
9898
callable = callable.expand_param_spec(nt)
9999

100100
# Apply arguments to argument types.

mypy/checker.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5224,7 +5224,7 @@ def check_subtype(self,
52245224
code: Optional[ErrorCode] = None,
52255225
outer_context: Optional[Context] = None) -> bool:
52265226
"""Generate an error if the subtype is not compatible with supertype."""
5227-
if is_subtype(subtype, supertype):
5227+
if is_subtype(subtype, supertype, options=self.options):
52285228
return True
52295229

52305230
if isinstance(msg, ErrorMessage):
@@ -5260,6 +5260,7 @@ def check_subtype(self,
52605260
self.msg.note(note, context, code=code)
52615261
if note_msg:
52625262
self.note(note_msg, context, code=code)
5263+
self.msg.maybe_note_concatenate_pos_args(subtype, supertype, context, code=code)
52635264
if (isinstance(supertype, Instance) and supertype.type.is_protocol and
52645265
isinstance(subtype, (Instance, TupleType, TypedDictType))):
52655266
self.msg.report_protocol_problems(subtype, supertype, context, code=code)

mypy/checkexpr.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1556,7 +1556,7 @@ def check_arg(self,
15561556
isinstance(callee_type.item, Instance) and
15571557
(callee_type.item.type.is_abstract or callee_type.item.type.is_protocol)):
15581558
self.msg.concrete_only_call(callee_type, context)
1559-
elif not is_subtype(caller_type, callee_type):
1559+
elif not is_subtype(caller_type, callee_type, options=self.chk.options):
15601560
if self.chk.should_suppress_optional_error([caller_type, callee_type]):
15611561
return
15621562
code = messages.incompatible_argument(n,

mypy/constraints.py

+75-7
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
99
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
1010
ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any,
11-
UnpackType, callable_with_ellipsis, TUPLE_LIKE_INSTANCE_NAMES,
11+
UnpackType, callable_with_ellipsis, Parameters, TUPLE_LIKE_INSTANCE_NAMES,
1212
)
1313
from mypy.maptype import map_instance_to_supertype
1414
import mypy.subtypes
@@ -406,6 +406,9 @@ def visit_param_spec(self, template: ParamSpecType) -> List[Constraint]:
406406
def visit_unpack_type(self, template: UnpackType) -> List[Constraint]:
407407
raise NotImplementedError
408408

409+
def visit_parameters(self, template: Parameters) -> List[Constraint]:
410+
raise RuntimeError("Parameters cannot be constrained to")
411+
409412
# Non-leaf types
410413

411414
def visit_instance(self, template: Instance) -> List[Constraint]:
@@ -446,7 +449,7 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
446449
# N.B: We use zip instead of indexing because the lengths might have
447450
# mismatches during daemon reprocessing.
448451
for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args):
449-
# TODO: ParamSpecType
452+
# TODO(PEP612): More ParamSpec work (or is Parameters the only thing accepted)
450453
if isinstance(tvar, TypeVarType):
451454
# The constraints for generic type parameters depend on variance.
452455
# Include constraints from both directions if invariant.
@@ -456,6 +459,27 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
456459
if tvar.variance != COVARIANT:
457460
res.extend(infer_constraints(
458461
mapped_arg, instance_arg, neg_op(self.direction)))
462+
elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType):
463+
suffix = get_proper_type(instance_arg)
464+
465+
if isinstance(suffix, CallableType):
466+
prefix = mapped_arg.prefix
467+
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
468+
suffix = suffix.copy_modified(from_concatenate=from_concat)
469+
470+
if isinstance(suffix, Parameters) or isinstance(suffix, CallableType):
471+
# no such thing as variance for ParamSpecs
472+
# TODO: is there a case I am missing?
473+
# TODO: constraints between prefixes
474+
prefix = mapped_arg.prefix
475+
suffix = suffix.copy_modified(
476+
suffix.arg_types[len(prefix.arg_types):],
477+
suffix.arg_kinds[len(prefix.arg_kinds):],
478+
suffix.arg_names[len(prefix.arg_names):])
479+
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
480+
elif isinstance(suffix, ParamSpecType):
481+
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
482+
459483
return res
460484
elif (self.direction == SUPERTYPE_OF and
461485
instance.type.has_base(template.type.fullname)):
@@ -464,7 +488,6 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
464488
# N.B: We use zip instead of indexing because the lengths might have
465489
# mismatches during daemon reprocessing.
466490
for tvar, mapped_arg, template_arg in zip(tvars, mapped.args, template.args):
467-
# TODO: ParamSpecType
468491
if isinstance(tvar, TypeVarType):
469492
# The constraints for generic type parameters depend on variance.
470493
# Include constraints from both directions if invariant.
@@ -474,6 +497,28 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
474497
if tvar.variance != COVARIANT:
475498
res.extend(infer_constraints(
476499
template_arg, mapped_arg, neg_op(self.direction)))
500+
elif (isinstance(tvar, ParamSpecType) and
501+
isinstance(template_arg, ParamSpecType)):
502+
suffix = get_proper_type(mapped_arg)
503+
504+
if isinstance(suffix, CallableType):
505+
prefix = template_arg.prefix
506+
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
507+
suffix = suffix.copy_modified(from_concatenate=from_concat)
508+
509+
if isinstance(suffix, Parameters) or isinstance(suffix, CallableType):
510+
# no such thing as variance for ParamSpecs
511+
# TODO: is there a case I am missing?
512+
# TODO: constraints between prefixes
513+
prefix = template_arg.prefix
514+
515+
suffix = suffix.copy_modified(
516+
suffix.arg_types[len(prefix.arg_types):],
517+
suffix.arg_kinds[len(prefix.arg_kinds):],
518+
suffix.arg_names[len(prefix.arg_names):])
519+
res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix))
520+
elif isinstance(suffix, ParamSpecType):
521+
res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix))
477522
return res
478523
if (template.type.is_protocol and self.direction == SUPERTYPE_OF and
479524
# We avoid infinite recursion for structural subtypes by checking
@@ -564,11 +609,34 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
564609
# Negate direction due to function argument type contravariance.
565610
res.extend(infer_constraints(t, a, neg_op(self.direction)))
566611
else:
612+
# sometimes, it appears we try to get constraints between two paramspec callables?
567613
# TODO: Direction
568-
# TODO: Deal with arguments that come before param spec ones?
569-
res.append(Constraint(param_spec.id,
570-
SUBTYPE_OF,
571-
cactual.copy_modified(ret_type=NoneType())))
614+
# TODO: check the prefixes match
615+
prefix = param_spec.prefix
616+
prefix_len = len(prefix.arg_types)
617+
cactual_ps = cactual.param_spec()
618+
619+
if not cactual_ps:
620+
res.append(Constraint(param_spec.id,
621+
SUBTYPE_OF,
622+
cactual.copy_modified(
623+
arg_types=cactual.arg_types[prefix_len:],
624+
arg_kinds=cactual.arg_kinds[prefix_len:],
625+
arg_names=cactual.arg_names[prefix_len:],
626+
ret_type=NoneType())))
627+
else:
628+
res.append(Constraint(param_spec.id, SUBTYPE_OF, cactual_ps))
629+
630+
# compare prefixes
631+
cactual_prefix = cactual.copy_modified(
632+
arg_types=cactual.arg_types[:prefix_len],
633+
arg_kinds=cactual.arg_kinds[:prefix_len],
634+
arg_names=cactual.arg_names[:prefix_len])
635+
636+
# TODO: see above "FIX" comments for param_spec is None case
637+
# TODO: this assume positional arguments
638+
for t, a in zip(prefix.arg_types, cactual_prefix.arg_types):
639+
res.extend(infer_constraints(t, a, neg_op(self.direction)))
572640

573641
template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
574642
if template.type_guard is not None:

mypy/erasetype.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType,
55
CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType,
66
DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType,
7-
get_proper_type, get_proper_types, TypeAliasType, ParamSpecType, UnpackType
7+
get_proper_type, get_proper_types, TypeAliasType, ParamSpecType, Parameters, UnpackType
88
)
99
from mypy.nodes import ARG_STAR, ARG_STAR2
1010

@@ -59,6 +59,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
5959
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
6060
return AnyType(TypeOfAny.special_form)
6161

62+
def visit_parameters(self, t: Parameters) -> ProperType:
63+
raise RuntimeError("Parameters should have been bound to a class")
64+
6265
def visit_unpack_type(self, t: UnpackType) -> ProperType:
6366
raise NotImplementedError
6467

mypy/expandtype.py

+40-8
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
NoneType, Overloaded, TupleType, TypedDictType, UnionType,
66
ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId,
77
FunctionLike, TypeVarType, LiteralType, get_proper_type, ProperType,
8-
TypeAliasType, ParamSpecType, TypeVarLikeType, UnpackType
8+
TypeAliasType, ParamSpecType, TypeVarLikeType, Parameters, ParamSpecFlavor,
9+
UnpackType
910
)
1011

1112

@@ -101,15 +102,41 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
101102
repl = get_proper_type(self.variables.get(t.id, t))
102103
if isinstance(repl, Instance):
103104
inst = repl
105+
# Return copy of instance with type erasure flag on.
106+
# TODO: what does prefix mean in this case?
107+
# TODO: why does this case even happen? Instances aren't plural.
104108
return Instance(inst.type, inst.args, line=inst.line, column=inst.column)
105109
elif isinstance(repl, ParamSpecType):
106-
return repl.with_flavor(t.flavor)
110+
return repl.copy_modified(flavor=t.flavor, prefix=t.prefix.copy_modified(
111+
arg_types=t.prefix.arg_types + repl.prefix.arg_types,
112+
arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
113+
arg_names=t.prefix.arg_names + repl.prefix.arg_names,
114+
))
115+
elif isinstance(repl, Parameters) or isinstance(repl, CallableType):
116+
# if the paramspec is *P.args or **P.kwargs:
117+
if t.flavor != ParamSpecFlavor.BARE:
118+
assert isinstance(repl, CallableType), "Should not be able to get here."
119+
# Is this always the right thing to do?
120+
param_spec = repl.param_spec()
121+
if param_spec:
122+
return param_spec.with_flavor(t.flavor)
123+
else:
124+
return repl
125+
else:
126+
return Parameters(t.prefix.arg_types + repl.arg_types,
127+
t.prefix.arg_kinds + repl.arg_kinds,
128+
t.prefix.arg_names + repl.arg_names,
129+
variables=[*t.prefix.variables, *repl.variables])
107130
else:
131+
# TODO: should this branch be removed? better not to fail silently
108132
return repl
109133

110134
def visit_unpack_type(self, t: UnpackType) -> Type:
111135
raise NotImplementedError
112136

137+
def visit_parameters(self, t: Parameters) -> Type:
138+
return t.copy_modified(arg_types=self.expand_types(t.arg_types))
139+
113140
def visit_callable_type(self, t: CallableType) -> Type:
114141
param_spec = t.param_spec()
115142
if param_spec is not None:
@@ -121,13 +148,18 @@ def visit_callable_type(self, t: CallableType) -> Type:
121148
# must expand both of them with all the argument types,
122149
# kinds and names in the replacement. The return type in
123150
# the replacement is ignored.
124-
if isinstance(repl, CallableType):
151+
if isinstance(repl, CallableType) or isinstance(repl, Parameters):
125152
# Substitute *args: P.args, **kwargs: P.kwargs
126-
t = t.expand_param_spec(repl)
127-
# TODO: Substitute remaining arg types
128-
return t.copy_modified(ret_type=t.ret_type.accept(self),
129-
type_guard=(t.type_guard.accept(self)
130-
if t.type_guard is not None else None))
153+
prefix = param_spec.prefix
154+
# we need to expand the types in the prefix, so might as well
155+
# not get them in the first place
156+
t = t.expand_param_spec(repl, no_prefix=True)
157+
return t.copy_modified(
158+
arg_types=self.expand_types(prefix.arg_types) + t.arg_types,
159+
arg_kinds=prefix.arg_kinds + t.arg_kinds,
160+
arg_names=prefix.arg_names + t.arg_names,
161+
ret_type=t.ret_type.accept(self),
162+
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None))
131163

132164
return t.copy_modified(arg_types=self.expand_types(t.arg_types),
133165
ret_type=t.ret_type.accept(self),

mypy/fixup.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
CallableType, Instance, Overloaded, TupleType, TypedDictType,
1212
TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType,
1313
TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, ParamSpecType,
14-
UnpackType,
14+
Parameters, UnpackType,
1515
)
1616
from mypy.visitor import NodeVisitor
1717
from mypy.lookup import lookup_fully_qualified
@@ -255,6 +255,11 @@ def visit_param_spec(self, p: ParamSpecType) -> None:
255255
def visit_unpack_type(self, u: UnpackType) -> None:
256256
u.type.accept(self)
257257

258+
def visit_parameters(self, p: Parameters) -> None:
259+
for argt in p.arg_types:
260+
if argt is not None:
261+
argt.accept(self)
262+
258263
def visit_unbound_type(self, o: UnboundType) -> None:
259264
for a in o.args:
260265
a.accept(self)

mypy/indirection.py

+3
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ def visit_param_spec(self, t: types.ParamSpecType) -> Set[str]:
7070
def visit_unpack_type(self, t: types.UnpackType) -> Set[str]:
7171
return t.type.accept(self)
7272

73+
def visit_parameters(self, t: types.Parameters) -> Set[str]:
74+
return self._visit(t.arg_types)
75+
7376
def visit_instance(self, t: types.Instance) -> Set[str]:
7477
out = self._visit(t.args)
7578
if t.type:

mypy/join.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType,
88
TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType,
99
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type,
10-
ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType,
10+
ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType, Parameters,
1111
UnpackType
1212
)
1313
from mypy.maptype import map_instance_to_supertype
@@ -260,6 +260,12 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType:
260260
def visit_unpack_type(self, t: UnpackType) -> UnpackType:
261261
raise NotImplementedError
262262

263+
def visit_parameters(self, t: Parameters) -> ProperType:
264+
if self.s == t:
265+
return t
266+
else:
267+
return self.default(self.s)
268+
263269
def visit_instance(self, t: Instance) -> ProperType:
264270
if isinstance(self.s, Instance):
265271
if self.instance_joiner is None:

mypy/main.py

+4
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,10 @@ def add_invertible_flag(flag: str,
679679
" non-overlapping types",
680680
group=strictness_group)
681681

682+
add_invertible_flag('--strict-concatenate', default=False, strict_flag=True,
683+
help="Make arguments prepended via Concatenate be truly positional-only",
684+
group=strictness_group)
685+
682686
strict_help = "Strict mode; enables the following flags: {}".format(
683687
", ".join(strict_flag_names))
684688
strictness_group.add_argument(

mypy/meet.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType,
77
UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType,
88
ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardedType,
9-
ParamSpecType, UnpackType,
9+
ParamSpecType, Parameters, UnpackType,
1010
)
1111
from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype
1212
from mypy.erasetype import erase_type
@@ -509,6 +509,17 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType:
509509
def visit_unpack_type(self, t: UnpackType) -> ProperType:
510510
raise NotImplementedError
511511

512+
def visit_parameters(self, t: Parameters) -> ProperType:
513+
# TODO: is this the right variance?
514+
if isinstance(self.s, Parameters) or isinstance(self.s, CallableType):
515+
if len(t.arg_types) != len(self.s.arg_types):
516+
return self.default(self.s)
517+
return t.copy_modified(
518+
arg_types=[meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)]
519+
)
520+
else:
521+
return self.default(self.s)
522+
512523
def visit_instance(self, t: Instance) -> ProperType:
513524
if isinstance(self.s, Instance):
514525
if t.type == self.s.type:

0 commit comments

Comments
 (0)