diff --git a/mypy/applytype.py b/mypy/applytype.py index e4a364ae383f..c3a6b53f2d42 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -4,7 +4,8 @@ import mypy.sametypes from mypy.expandtype import expand_type from mypy.types import ( - Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types + Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types, + TypeVarDef, TypeVarLikeDef, ProperType ) from mypy.nodes import Context @@ -29,20 +30,22 @@ def apply_generic_arguments( # Check that inferred type variable values are compatible with allowed # values and bounds. Also, promote subtype values to allowed values. types = get_proper_types(orig_types) - for i, type in enumerate(types): - assert not isinstance(type, PartialType), "Internal error: must never apply partial type" - values = get_proper_types(callable.variables[i].values) - if type is None: - continue + + # Create a map from type variable id to target type. + id_to_type = {} # type: Dict[TypeVarId, Type] + + def get_target_type(tvar: TypeVarLikeDef, type: ProperType) -> Optional[Type]: + assert isinstance(tvar, TypeVarDef), "TODO(shantanu) paramspec" + values = get_proper_types(tvar.values) if values: if isinstance(type, AnyType): - continue + return type if isinstance(type, TypeVarType) and type.values: # Allow substituting T1 for T if every allowed value of T1 # is also a legal value of T. if all(any(mypy.sametypes.is_same_type(v, v1) for v in values) for v1 in type.values): - continue + return type matching = [] for value in values: if mypy.subtypes.is_subtype(type, value): @@ -53,28 +56,26 @@ def apply_generic_arguments( for match in matching[1:]: if mypy.subtypes.is_subtype(match, best): best = match - types[i] = best - else: - if skip_unsatisfied: - types[i] = None - else: - report_incompatible_typevar_value(callable, type, callable.variables[i].name, - context) + return best + if skip_unsatisfied: + return None + report_incompatible_typevar_value(callable, type, tvar.name, context) else: - upper_bound = callable.variables[i].upper_bound + upper_bound = tvar.upper_bound if not mypy.subtypes.is_subtype(type, upper_bound): if skip_unsatisfied: - types[i] = None - else: - report_incompatible_typevar_value(callable, type, callable.variables[i].name, - context) + return None + report_incompatible_typevar_value(callable, type, tvar.name, context) + return type - # Create a map from type variable id to target type. - id_to_type = {} # type: Dict[TypeVarId, Type] - for i, tv in enumerate(tvars): - typ = types[i] - if typ: - id_to_type[tv.id] = typ + for tvar, type in zip(tvars, types): + assert not isinstance(type, PartialType), "Internal error: must never apply partial type" + if type is None: + continue + + target_type = get_target_type(tvar, type) + if target_type is not None: + id_to_type[tvar.id] = target_type # Apply arguments to argument types. arg_types = [expand_type(at, id_to_type) for at in callable.arg_types] diff --git a/mypy/checker.py b/mypy/checker.py index 75739fe87a00..93aba50425e7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -50,7 +50,7 @@ erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, try_getting_str_literals_from_type, try_getting_int_literals_from_type, tuple_fallback, is_singleton_type, try_expanding_enum_to_union, - true_only, false_only, function_type, TypeVarExtractor, custom_special_method, + true_only, false_only, function_type, get_type_vars, custom_special_method, is_literal_type_like, ) from mypy import message_registry @@ -1389,15 +1389,14 @@ def expand_typevars(self, defn: FuncItem, typ: CallableType) -> List[Tuple[FuncItem, CallableType]]: # TODO use generator subst = [] # type: List[List[Tuple[TypeVarId, Type]]] - tvars = typ.variables or [] - tvars = tvars[:] + tvars = list(typ.variables) or [] if defn.info: # Class type variables tvars += defn.info.defn.type_vars or [] + # TODO(shantanu): audit for paramspec for tvar in tvars: - if tvar.values: - subst.append([(tvar.id, value) - for value in tvar.values]) + if isinstance(tvar, TypeVarDef) and tvar.values: + subst.append([(tvar.id, value) for value in tvar.values]) # Make a copy of the function to check for each combination of # value restricted type variables. (Except when running mypyc, # where we need one canonical version of the function.) @@ -5325,7 +5324,7 @@ def detach_callable(typ: CallableType) -> CallableType: appear_map = {} # type: Dict[str, List[int]] for i, inner_type in enumerate(type_list): - typevars_available = inner_type.accept(TypeVarExtractor()) + typevars_available = get_type_vars(inner_type) for var in typevars_available: if var.fullname not in appear_map: appear_map[var.fullname] = [] @@ -5335,7 +5334,7 @@ def detach_callable(typ: CallableType) -> CallableType: for var_name, appearances in appear_map.items(): used_type_var_names.add(var_name) - all_type_vars = typ.accept(TypeVarExtractor()) + all_type_vars = get_type_vars(typ) new_variables = [] for var in set(all_type_vars): if var.fullname not in used_type_var_names: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 5af114767357..f3bc8f22e003 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -31,6 +31,7 @@ DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr, YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode, + ParamSpecExpr, ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, ) from mypy.literals import literal @@ -3969,6 +3970,9 @@ def visit_temp_node(self, e: TempNode) -> Type: def visit_type_var_expr(self, e: TypeVarExpr) -> Type: return AnyType(TypeOfAny.special_form) + def visit_paramspec_var_expr(self, e: ParamSpecExpr) -> Type: + return AnyType(TypeOfAny.special_form) + def visit_newtype_expr(self, e: NewTypeExpr) -> Type: return AnyType(TypeOfAny.special_form) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index c9a5a2c86d97..dd1ae9b6527b 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1,11 +1,11 @@ """Type checking of attribute access""" -from typing import cast, Callable, Optional, Union, List +from typing import cast, Callable, Optional, Union, List, Sequence from typing_extensions import TYPE_CHECKING from mypy.types import ( Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, TypeVarDef, - Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType, + TypeVarLikeDef, Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType, DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType ) from mypy.nodes import ( @@ -676,7 +676,7 @@ def analyze_class_attribute_access(itype: Instance, name: str, mx: MemberContext, override_info: Optional[TypeInfo] = None, - original_vars: Optional[List[TypeVarDef]] = None + original_vars: Optional[Sequence[TypeVarLikeDef]] = None ) -> Optional[Type]: """Analyze access to an attribute on a class object. @@ -839,7 +839,7 @@ def analyze_enum_class_attribute_access(itype: Instance, def add_class_tvars(t: ProperType, isuper: Optional[Instance], is_classmethod: bool, original_type: Type, - original_vars: Optional[List[TypeVarDef]] = None) -> Type: + original_vars: Optional[Sequence[TypeVarLikeDef]] = None) -> Type: """Instantiate type variables during analyze_class_attribute_access, e.g T and Q in the following: @@ -883,7 +883,7 @@ class B(A[str]): pass assert isuper is not None t = cast(CallableType, expand_type_by_instance(t, isuper)) freeze_type_vars(t) - return t.copy_modified(variables=tvars + t.variables) + return t.copy_modified(variables=list(tvars) + list(t.variables)) elif isinstance(t, Overloaded): return Overloaded([cast(CallableType, add_class_tvars(item, isuper, is_classmethod, original_type, diff --git a/mypy/expandtype.py b/mypy/expandtype.py index b805f3c0be83..ebdc8e328806 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -5,7 +5,7 @@ NoneType, TypeVarType, Overloaded, TupleType, TypedDictType, UnionType, ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, FunctionLike, TypeVarDef, LiteralType, get_proper_type, ProperType, - TypeAliasType) + TypeAliasType, ParamSpecDef, ParamSpecType) def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: @@ -40,9 +40,15 @@ def freshen_function_type_vars(callee: F) -> F: tvdefs = [] tvmap = {} # type: Dict[TypeVarId, Type] for v in callee.variables: - tvdef = TypeVarDef.new_unification_variable(v) + tvdef = v.new_unification_variable() tvdefs.append(tvdef) - tvmap[v.id] = TypeVarType(tvdef) + if isinstance(tvdef, TypeVarDef): + tvtype = TypeVarType(tvdef) # type: Type + elif isinstance(tvdef, ParamSpecDef): + tvtype = ParamSpecType(tvdef) + else: + assert False + tvmap[v.id] = tvtype fresh = cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvdefs) return cast(F, fresh) else: diff --git a/mypy/fixup.py b/mypy/fixup.py index 023df1e31331..2eaf904441fa 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -11,7 +11,8 @@ from mypy.types import ( CallableType, Instance, Overloaded, TupleType, TypedDictType, TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, - TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny) + TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, TypeVarDef +) from mypy.visitor import NodeVisitor from mypy.lookup import lookup_fully_qualified @@ -183,10 +184,13 @@ def visit_callable_type(self, ct: CallableType) -> None: if ct.ret_type is not None: ct.ret_type.accept(self) for v in ct.variables: - if v.values: - for val in v.values: - val.accept(self) - v.upper_bound.accept(self) + if isinstance(v, TypeVarDef): + if v.values: + for val in v.values: + val.accept(self) + v.upper_bound.accept(self) + else: + assert False # TODO(shantanu): add param spec for arg in ct.bound_args: if arg: arg.accept(self) diff --git a/mypy/literals.py b/mypy/literals.py index 4779abf871c9..76b5c406da62 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -8,7 +8,7 @@ ConditionalExpr, EllipsisExpr, YieldFromExpr, YieldExpr, RevealExpr, SuperExpr, TypeApplication, LambdaExpr, ListComprehension, SetComprehension, DictionaryComprehension, GeneratorExpr, BackquoteExpr, TypeVarExpr, TypeAliasExpr, NamedTupleExpr, EnumCallExpr, - TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, + TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, ParamSpecExpr ) from mypy.visitor import ExpressionVisitor @@ -213,6 +213,9 @@ def visit_backquote_expr(self, e: BackquoteExpr) -> None: def visit_type_var_expr(self, e: TypeVarExpr) -> None: return None + def visit_paramspec_var_expr(self, e: ParamSpecExpr) -> None: + return None + def visit_type_alias_expr(self, e: TypeAliasExpr) -> None: return None diff --git a/mypy/messages.py b/mypy/messages.py index 8b689861548f..ea93f61734bb 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -21,7 +21,7 @@ from mypy.errors import Errors from mypy.types import ( Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType, LiteralType, - UnionType, NoneType, AnyType, Overloaded, FunctionLike, DeletedType, TypeType, + UnionType, NoneType, AnyType, Overloaded, FunctionLike, DeletedType, TypeType, TypeVarDef, UninhabitedType, TypeOfAny, UnboundType, PartialType, get_proper_type, ProperType, get_proper_types ) @@ -1868,16 +1868,20 @@ def [T <: int] f(self, x: int, y: T) -> None if tp.variables: tvars = [] for tvar in tp.variables: - upper_bound = get_proper_type(tvar.upper_bound) - if (isinstance(upper_bound, Instance) and - upper_bound.type.fullname != 'builtins.object'): - tvars.append('{} <: {}'.format(tvar.name, format_type_bare(upper_bound))) - elif tvar.values: - tvars.append('{} in ({})' - .format(tvar.name, ', '.join([format_type_bare(tp) - for tp in tvar.values]))) + if isinstance(tvar, TypeVarDef): + upper_bound = get_proper_type(tvar.upper_bound) + if (isinstance(upper_bound, Instance) and + upper_bound.type.fullname != 'builtins.object'): + tvars.append('{} <: {}'.format(tvar.name, format_type_bare(upper_bound))) + elif tvar.values: + tvars.append('{} in ({})' + .format(tvar.name, ', '.join([format_type_bare(tp) + for tp in tvar.values]))) + else: + tvars.append(tvar.name) else: - tvars.append(tvar.name) + # For other TypeVarLikeDefs, just use the repr + tvars.append(repr(tvar)) s = '[{}] {}'.format(', '.join(tvars), s) return 'def {}'.format(s) diff --git a/mypy/nodes.py b/mypy/nodes.py index 8ccb522323ba..f208b4f5e13c 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2043,23 +2043,10 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: CONTRAVARIANT = 2 # type: Final[int] -class TypeVarExpr(SymbolNode, Expression): - """Type variable expression TypeVar(...). - - This is also used to represent type variables in symbol tables. - - A type variable is not valid as a type unless bound in a TypeVarScope. - That happens within: - - 1. a generic class that uses the type variable as a type argument or - 2. a generic function that refers to the type variable in its signature. - """ - +class TypeVarLikeExpr(SymbolNode, Expression): + """Base class for TypeVarExpr and ParamSpecExpr.""" _name = '' _fullname = '' - # Value restriction: only types in the list are valid as values. If the - # list is empty, there is no restriction. - values = None # type: List[mypy.types.Type] # Upper bound: only subtypes of upper_bound are valid as values. By default # this is 'object', meaning no restriction. upper_bound = None # type: mypy.types.Type @@ -2069,14 +2056,12 @@ class TypeVarExpr(SymbolNode, Expression): # variable. variance = INVARIANT - def __init__(self, name: str, fullname: str, - values: List['mypy.types.Type'], - upper_bound: 'mypy.types.Type', - variance: int = INVARIANT) -> None: + def __init__( + self, name: str, fullname: str, upper_bound: 'mypy.types.Type', variance: int = INVARIANT + ) -> None: super().__init__() self._name = name self._fullname = fullname - self.values = values self.upper_bound = upper_bound self.variance = variance @@ -2088,6 +2073,29 @@ def name(self) -> str: def fullname(self) -> str: return self._fullname + +class TypeVarExpr(TypeVarLikeExpr): + """Type variable expression TypeVar(...). + + This is also used to represent type variables in symbol tables. + + A type variable is not valid as a type unless bound in a TypeVarLikeScope. + That happens within: + + 1. a generic class that uses the type variable as a type argument or + 2. a generic function that refers to the type variable in its signature. + """ + # Value restriction: only types in the list are valid as values. If the + # list is empty, there is no restriction. + values = None # type: List[mypy.types.Type] + + def __init__(self, name: str, fullname: str, + values: List['mypy.types.Type'], + upper_bound: 'mypy.types.Type', + variance: int = INVARIANT) -> None: + super().__init__(name, fullname, upper_bound, variance) + self.values = values + def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_type_var_expr(self) @@ -2110,6 +2118,30 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarExpr': data['variance']) +class ParamSpecExpr(TypeVarLikeExpr): + def accept(self, visitor: ExpressionVisitor[T]) -> T: + return visitor.visit_paramspec_var_expr(self) + + def serialize(self) -> JsonDict: + return { + '.class': 'ParamSpecExpr', + 'name': self._name, + 'fullname': self._fullname, + 'upper_bound': self.upper_bound.serialize(), + 'variance': self.variance, + } + + @classmethod + def deserialize(cls, data: JsonDict) -> 'ParamSpecExpr': + assert data['.class'] == 'ParamSpecExpr' + return ParamSpecExpr( + data['name'], + data['fullname'], + mypy.types.deserialize_type(data['upper_bound']), + data['variance'] + ) + + class TypeAliasExpr(Expression): """Type alias expression (rvalue).""" diff --git a/mypy/plugin.py b/mypy/plugin.py index ed2d80cfaf29..eb31878b62a7 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -126,7 +126,7 @@ class C: pass from mypy.nodes import ( Expression, Context, ClassDef, SymbolTableNode, MypyFile, CallExpr ) -from mypy.tvar_scope import TypeVarScope +from mypy.tvar_scope import TypeVarLikeScope from mypy.types import Type, Instance, CallableType, TypeList, UnboundType, ProperType from mypy.messages import MessageBuilder from mypy.options import Options @@ -265,7 +265,7 @@ def fail(self, msg: str, ctx: Context, serious: bool = False, *, @abstractmethod def anal_type(self, t: Type, *, - tvar_scope: Optional[TypeVarScope] = None, + tvar_scope: Optional[TypeVarLikeScope] = None, allow_tuple_literal: bool = False, allow_unbound_tvars: bool = False, report_invalid_types: bool = True, diff --git a/mypy/semanal.py b/mypy/semanal.py index 24c9cb7a9e5f..7c4f52eeeec9 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -76,8 +76,9 @@ nongen_builtins, get_member_expr_fullname, REVEAL_TYPE, REVEAL_LOCALS, is_final_node, TypedDictExpr, type_aliases_target_versions, EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr, + ParamSpecExpr ) -from mypy.tvar_scope import TypeVarScope +from mypy.tvar_scope import TypeVarLikeScope from mypy.typevars import fill_typevars from mypy.visitor import NodeVisitor from mypy.errors import Errors, report_internal_error @@ -96,7 +97,7 @@ from mypy.nodes import implicit_module_attrs from mypy.typeanal import ( TypeAnalyser, analyze_type_alias, no_subscript_builtin_alias, - TypeVariableQuery, TypeVarList, remove_dups, has_any_from_unimported_type, + TypeVarLikeQuery, TypeVarLikeList, remove_dups, has_any_from_unimported_type, check_for_explicit_any, type_constructors, fix_instance_types ) from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError @@ -160,7 +161,7 @@ class SemanticAnalyzer(NodeVisitor[None], # Stack of outer classes (the second tuple item contains tvars). type_stack = None # type: List[Optional[TypeInfo]] # Type variables bound by the current scope, be it class or function - tvar_scope = None # type: TypeVarScope + tvar_scope = None # type: TypeVarLikeScope # Per-module options options = None # type: Options @@ -234,7 +235,7 @@ def __init__(self, self.imports = set() self.type = None self.type_stack = [] - self.tvar_scope = TypeVarScope() + self.tvar_scope = TypeVarLikeScope() self.function_stack = [] self.block_depth = [0] self.loop_depth = 0 @@ -477,7 +478,7 @@ def file_context(self, self.is_stub_file = file_node.path.lower().endswith('.pyi') self._is_typeshed_stub_file = is_typeshed_file(file_node.path) self.globals = file_node.names - self.tvar_scope = TypeVarScope() + self.tvar_scope = TypeVarLikeScope() self.named_tuple_analyzer = NamedTupleAnalyzer(options, self) self.typed_dict_analyzer = TypedDictAnalyzer(options, self, self.msg) @@ -1211,7 +1212,7 @@ class Foo(Bar, Generic[T]): ... Returns (remaining base expressions, inferred type variables, is protocol). """ removed = [] # type: List[int] - declared_tvars = [] # type: TypeVarList + declared_tvars = [] # type: TypeVarLikeList is_protocol = False for i, base_expr in enumerate(base_type_exprs): self.analyze_type_expr(base_expr) @@ -1259,10 +1260,13 @@ class Foo(Bar, Generic[T]): ... tvar_defs = [] # type: List[TypeVarDef] for name, tvar_expr in declared_tvars: tvar_def = self.tvar_scope.bind_new(name, tvar_expr) + assert isinstance(tvar_def, TypeVarDef), ( + "mypy does not currently support ParamSpec use in generic classes" + ) tvar_defs.append(tvar_def) return base_type_exprs, tvar_defs, is_protocol - def analyze_class_typevar_declaration(self, base: Type) -> Optional[Tuple[TypeVarList, bool]]: + def analyze_class_typevar_declaration(self, base: Type) -> Optional[Tuple[TypeVarLikeList, bool]]: """Analyze type variables declared using Generic[...] or Protocol[...]. Args: @@ -1281,7 +1285,7 @@ def analyze_class_typevar_declaration(self, base: Type) -> Optional[Tuple[TypeVa sym.node.fullname == 'typing.Protocol' and base.args or sym.node.fullname == 'typing_extensions.Protocol' and base.args): is_proto = sym.node.fullname != 'typing.Generic' - tvars = [] # type: TypeVarList + tvars = [] # type: TypeVarLikeList for arg in unbound.args: tag = self.track_incomplete_refs() tvar = self.analyze_unbound_tvar(arg) @@ -1311,9 +1315,9 @@ def analyze_unbound_tvar(self, t: Type) -> Optional[Tuple[str, TypeVarExpr]]: def get_all_bases_tvars(self, base_type_exprs: List[Expression], - removed: List[int]) -> TypeVarList: + removed: List[int]) -> TypeVarLikeList: """Return all type variable references in bases.""" - tvars = [] # type: TypeVarList + tvars = [] # type: TypeVarLikeList for i, base_expr in enumerate(base_type_exprs): if i not in removed: try: @@ -1321,7 +1325,7 @@ def get_all_bases_tvars(self, except TypeTranslationError: # This error will be caught later. continue - base_tvars = base.accept(TypeVariableQuery(self.lookup_qualified, self.tvar_scope)) + base_tvars = base.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope)) tvars.extend(base_tvars) return remove_dups(tvars) @@ -1921,6 +1925,8 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: # * type variable definition elif self.process_typevar_declaration(s): special_form = True + elif self.process_paramspec_declaration(s): + special_form = True # * type constructors elif self.analyze_namedtuple_assign(s): special_form = True @@ -2400,7 +2406,7 @@ def analyze_alias(self, rvalue: Expression, typ = None # type: Optional[Type] if res: typ, depends_on = res - found_type_vars = typ.accept(TypeVariableQuery(self.lookup_qualified, self.tvar_scope)) + found_type_vars = typ.accept(TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope)) alias_tvars = [name for (name, node) in found_type_vars] qualified_tvars = [node.fullname for (name, node) in found_type_vars] else: @@ -2823,7 +2829,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: Return True if this looks like a type variable declaration (but maybe with errors), otherwise return False. """ - call = self.get_typevar_declaration(s) + call = self.get_typevarlike_declaration(s, "typing.TypeVar") if not call: return False @@ -2834,7 +2840,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: return False name = lvalue.name - if not self.check_typevar_name(call, name, s): + if not self.check_typevarlike_name(call, name, s): return False # Constraining types @@ -2894,24 +2900,31 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: self.add_symbol(name, call.analyzed, s) return True - def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> bool: + def check_typevarlike_name(self, call: CallExpr, name: str, context: Context) -> bool: + """Checks that the name of a TypeVar or ParamSpec matches its variable.""" name = unmangle(name) + assert isinstance(call.callee, RefExpr) + typevarlike_type = ( + call.callee.name if isinstance(call.callee, NameExpr) else call.callee.fullname + ) if len(call.args) < 1: - self.fail("Too few arguments for TypeVar()", context) + self.fail("Too few arguments for {}()".format(typevarlike_type), context) return False if (not isinstance(call.args[0], (StrExpr, BytesExpr, UnicodeExpr)) or not call.arg_kinds[0] == ARG_POS): - self.fail("TypeVar() expects a string literal as first argument", context) + self.fail("{}() expects a string literal as first argument".format(typevarlike_type), + context) return False elif call.args[0].value != name: - msg = "String argument 1 '{}' to TypeVar(...) does not match variable name '{}'" - self.fail(msg.format(call.args[0].value, name), context) + msg = "String argument 1 '{}' to {}(...) does not match variable name '{}'" + self.fail(msg.format(call.args[0].value, typevarlike_type, name), context) return False return True - def get_typevar_declaration(self, s: AssignmentStmt) -> Optional[CallExpr]: - """Returns the TypeVar() call expression if `s` is a type var declaration - or None otherwise. + def get_typevarlike_declaration(self, s: AssignmentStmt, + typevarlike_type: str) -> Optional[CallExpr]: + """Returns the call expression if `s` is a declaration of `typevarlike_type` + (TypeVar or ParamSpec), or None otherwise. """ if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): return None @@ -2921,7 +2934,7 @@ def get_typevar_declaration(self, s: AssignmentStmt) -> Optional[CallExpr]: callee = call.callee if not isinstance(callee, RefExpr): return None - if callee.fullname != 'typing.TypeVar': + if callee.fullname != typevarlike_type: return None return call @@ -3008,6 +3021,41 @@ def process_typevar_parameters(self, args: List[Expression], variance = INVARIANT return variance, upper_bound + def process_paramspec_declaration(self, s: AssignmentStmt) -> bool: + """Checks if s declares a ParamSpec; if yes, store it in symbol table. + + Return True if this looks like a parameter specification declaration (but maybe + with errors), otherwise return False. + + In the future, ParamSpec may accept bounds and variance arguments, in which + case more aggressive sharing of code with process_typevar_declaration should be pursued. + + """ + call = self.get_typevarlike_declaration(s, "mypy_extensions.ParamSpec") + if not call: + return False + + lvalue = s.lvalues[0] + assert isinstance(lvalue, NameExpr) + if s.type: + self.fail("Cannot declare the type of a parameter specification", s) + return False + + name = lvalue.name + if not self.check_typevarlike_name(call, name, s): + return False + + # PEP 612 reserves the right to define bound, covariant and contravariant arguments to + # ParamSpec in a later PEP. If and when that happens, we should do something + # on the lines of process_typevar_parameters + paramspec_var = ParamSpecExpr( + name, self.qualified_name(name), self.object_type(), INVARIANT + ) + paramspec_var.line = call.line + call.analyzed = paramspec_var + self.add_symbol(name, call.analyzed, s) + return True + def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance) -> TypeInfo: class_def = ClassDef(name, Block([])) if self.is_func_scope() and not self.type: @@ -4397,7 +4445,7 @@ def add_unknown_imported_symbol(self, # @contextmanager - def tvar_scope_frame(self, frame: TypeVarScope) -> Iterator[None]: + def tvar_scope_frame(self, frame: TypeVarLikeScope) -> Iterator[None]: old_scope = self.tvar_scope self.tvar_scope = frame yield @@ -4721,11 +4769,11 @@ def analyze_type_expr(self, expr: Expression) -> None: # them semantically analyzed, however, if they need to treat it as an expression # and not a type. (Which is to say, mypyc needs to do this.) Do the analysis # in a fresh tvar scope in order to suppress any errors about using type variables. - with self.tvar_scope_frame(TypeVarScope()): + with self.tvar_scope_frame(TypeVarLikeScope()): expr.accept(self) def type_analyzer(self, *, - tvar_scope: Optional[TypeVarScope] = None, + tvar_scope: Optional[TypeVarLikeScope] = None, allow_tuple_literal: bool = False, allow_unbound_tvars: bool = False, allow_placeholder: bool = False, @@ -4748,7 +4796,7 @@ def type_analyzer(self, *, def anal_type(self, typ: Type, *, - tvar_scope: Optional[TypeVarScope] = None, + tvar_scope: Optional[TypeVarLikeScope] = None, allow_tuple_literal: bool = False, allow_unbound_tvars: bool = False, allow_placeholder: bool = False, diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index ba0972e8c302..3fd2498c696f 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -14,7 +14,7 @@ from mypy.types import ( Type, FunctionLike, Instance, TupleType, TPDICT_FB_NAMES, ProperType, get_proper_type ) -from mypy.tvar_scope import TypeVarScope +from mypy.tvar_scope import TypeVarLikeScope from mypy.errorcodes import ErrorCode from mypy import join @@ -105,7 +105,7 @@ def accept(self, node: Node) -> None: @abstractmethod def anal_type(self, t: Type, *, - tvar_scope: Optional[TypeVarScope] = None, + tvar_scope: Optional[TypeVarLikeScope] = None, allow_tuple_literal: bool = False, allow_unbound_tvars: bool = False, report_invalid_types: bool = True) -> Optional[Type]: diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 49a85861b6e3..0c7fba2a2b5a 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -89,6 +89,7 @@ 'check-reports.test', 'check-errorcodes.test', 'check-annotated.test', + 'check-parameter-specification.test', ] # Tests that use Python 3.8-only AST features (like expression-scoped ignores): diff --git a/mypy/tvar_scope.py b/mypy/tvar_scope.py index 7d5dce18fc66..e00bbe2bffe1 100644 --- a/mypy/tvar_scope.py +++ b/mypy/tvar_scope.py @@ -1,16 +1,20 @@ from typing import Optional, Dict, Union -from mypy.types import TypeVarDef -from mypy.nodes import TypeVarExpr, SymbolTableNode +from mypy.types import TypeVarLikeDef, TypeVarDef, ParamSpecDef +from mypy.nodes import ParamSpecExpr, TypeVarExpr, TypeVarLikeExpr, SymbolTableNode -class TypeVarScope: - """Scope that holds bindings for type variables. Node fullname -> TypeVarDef.""" +class TypeVarLikeScope: + """Scope that holds bindings for type variables and parameter specifications. + + Node fullname -> TypeVarLikeDef. + + """ def __init__(self, - parent: 'Optional[TypeVarScope]' = None, + parent: 'Optional[TypeVarLikeScope]' = None, is_class_scope: bool = False, - prohibited: 'Optional[TypeVarScope]' = None) -> None: - """Initializer for TypeVarScope + prohibited: 'Optional[TypeVarLikeScope]' = None) -> None: + """Initializer for TypeVarLikeScope Parameters: parent: the outer scope for this scope @@ -18,7 +22,7 @@ def __init__(self, prohibited: Type variables that aren't strictly in scope exactly, but can't be bound because they're part of an outer class's scope. """ - self.scope = {} # type: Dict[str, TypeVarDef] + self.scope = {} # type: Dict[str, TypeVarLikeDef] self.parent = parent self.func_id = 0 self.class_id = 0 @@ -28,9 +32,9 @@ def __init__(self, self.func_id = parent.func_id self.class_id = parent.class_id - def get_function_scope(self) -> 'Optional[TypeVarScope]': + def get_function_scope(self) -> 'Optional[TypeVarLikeScope]': """Get the nearest parent that's a function scope, not a class scope""" - it = self # type: Optional[TypeVarScope] + it = self # type: Optional[TypeVarLikeScope] while it is not None and it.is_class_scope: it = it.parent return it @@ -44,36 +48,49 @@ def allow_binding(self, fullname: str) -> bool: return False return True - def method_frame(self) -> 'TypeVarScope': + def method_frame(self) -> 'TypeVarLikeScope': """A new scope frame for binding a method""" - return TypeVarScope(self, False, None) + return TypeVarLikeScope(self, False, None) - def class_frame(self) -> 'TypeVarScope': + def class_frame(self) -> 'TypeVarLikeScope': """A new scope frame for binding a class. Prohibits *this* class's tvars""" - return TypeVarScope(self.get_function_scope(), True, self) + return TypeVarLikeScope(self.get_function_scope(), True, self) - def bind_new(self, name: str, tvar_expr: TypeVarExpr) -> TypeVarDef: + def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeDef: if self.is_class_scope: self.class_id += 1 i = self.class_id else: self.func_id -= 1 i = self.func_id - tvar_def = TypeVarDef(name, - tvar_expr.fullname, - i, - values=tvar_expr.values, - upper_bound=tvar_expr.upper_bound, - variance=tvar_expr.variance, - line=tvar_expr.line, - column=tvar_expr.column) + if isinstance(tvar_expr, TypeVarExpr): + tvar_def = TypeVarDef( + name, + tvar_expr.fullname, + i, + values=tvar_expr.values, + upper_bound=tvar_expr.upper_bound, + variance=tvar_expr.variance, + line=tvar_expr.line, + column=tvar_expr.column + ) # type: TypeVarLikeDef + elif isinstance(tvar_expr, ParamSpecExpr): + tvar_def = ParamSpecDef( + name, + tvar_expr.fullname, + i, + line=tvar_expr.line, + column=tvar_expr.column + ) + else: + assert False self.scope[tvar_expr.fullname] = tvar_def return tvar_def - def bind_existing(self, tvar_def: TypeVarDef) -> None: + def bind_existing(self, tvar_def: TypeVarLikeDef) -> None: self.scope[tvar_def.fullname] = tvar_def - def get_binding(self, item: Union[str, SymbolTableNode]) -> Optional[TypeVarDef]: + def get_binding(self, item: Union[str, SymbolTableNode]) -> Optional[TypeVarLikeDef]: fullname = item.fullname if isinstance(item, SymbolTableNode) else item assert fullname is not None if fullname in self.scope: diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 905f46a92576..10f3fd8f3698 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -13,7 +13,7 @@ from abc import abstractmethod from mypy.ordered_dict import OrderedDict -from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional, Set +from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional, Set, Sequence from mypy_extensions import trait T = TypeVar('T') @@ -21,9 +21,9 @@ from mypy.types import ( Type, AnyType, CallableType, Overloaded, TupleType, TypedDictType, LiteralType, RawExpressionType, Instance, NoneType, TypeType, - UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, + UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef, TypeVarLikeDef, UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, - PlaceholderType, TypeAliasType, get_proper_type + PlaceholderType, TypeAliasType, ParamSpecType, get_proper_type ) @@ -62,6 +62,10 @@ def visit_deleted_type(self, t: DeletedType) -> T: def visit_type_var(self, t: TypeVarType) -> T: pass + @abstractmethod + def visit_param_spec(self, t: ParamSpecType) -> T: + pass + @abstractmethod def visit_instance(self, t: Instance) -> T: pass @@ -219,7 +223,7 @@ def translate_types(self, types: Iterable[Type]) -> List[Type]: return [t.accept(self) for t in types] def translate_variables(self, - variables: List[TypeVarDef]) -> List[TypeVarDef]: + variables: Sequence[TypeVarLikeDef]) -> Sequence[TypeVarLikeDef]: return variables def visit_overloaded(self, t: Overloaded) -> Type: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index f1a96eacd23e..a23057aeac3e 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -16,17 +16,17 @@ CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarDef, SyntheticTypeVisitor, StarType, PartialType, EllipsisType, UninhabitedType, TypeType, CallableArgument, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, - PlaceholderType, Overloaded, get_proper_type, TypeAliasType + PlaceholderType, Overloaded, get_proper_type, TypeAliasType, TypeVarLikeDef, ParamSpecDef ) from mypy.nodes import ( TypeInfo, Context, SymbolTableNode, Var, Expression, nongen_builtins, check_arg_names, check_arg_kinds, ARG_POS, ARG_NAMED, - ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, + ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, TypeVarLikeExpr, ParamSpecExpr, TypeAlias, PlaceholderNode, SYMBOL_FUNCBASE_TYPES, Decorator, MypyFile ) from mypy.typetraverser import TypeTraverserVisitor -from mypy.tvar_scope import TypeVarScope +from mypy.tvar_scope import TypeVarLikeScope from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.plugin import Plugin, TypeAnalyzerPluginInterface, AnalyzeTypeContext from mypy.semanal_shared import SemanticAnalyzerCoreInterface @@ -64,7 +64,7 @@ def analyze_type_alias(node: Expression, api: SemanticAnalyzerCoreInterface, - tvar_scope: TypeVarScope, + tvar_scope: TypeVarLikeScope, plugin: Plugin, options: Options, is_typeshed_stub: bool, @@ -117,7 +117,7 @@ class TypeAnalyser(SyntheticTypeVisitor[Type], TypeAnalyzerPluginInterface): def __init__(self, api: SemanticAnalyzerCoreInterface, - tvar_scope: TypeVarScope, + tvar_scope: TypeVarLikeScope, plugin: Plugin, options: Options, is_typeshed_stub: bool, *, @@ -200,11 +200,18 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) self.fail(no_subscript_builtin_alias(fullname, propose_alt=not self.defining_alias), t) tvar_def = self.tvar_scope.get_binding(sym) + if isinstance(sym.node, ParamSpecExpr): + if tvar_def is None: + self.fail('ParamSpec "{}" is unbound'.format(t.name), t) + return AnyType(TypeOfAny.from_error) + self.fail('Invalid location for ParamSpec "{}"'.format(t.name), t) + return AnyType(TypeOfAny.from_error) if isinstance(sym.node, TypeVarExpr) and tvar_def is not None and self.defining_alias: self.fail('Can\'t use bound type variable "{}"' ' to define generic alias'.format(t.name), t) return AnyType(TypeOfAny.from_error) if isinstance(sym.node, TypeVarExpr) and tvar_def is not None: + assert isinstance(tvar_def, TypeVarDef) if len(t.args) > 0: self.fail('Type variable "{}" used with arguments'.format(t.name), t) return TypeVarType(tvar_def, t.line) @@ -641,7 +648,16 @@ def analyze_callable_type(self, t: UnboundType) -> Type: fallback=fallback, is_ellipsis_args=True) else: - self.fail('The first argument to Callable must be a list of types or "..."', t) + args = t.args[0] + sym = self.lookup_qualified(args.name, args) + tvar_def = self.tvar_scope.get_binding(sym) + if not isinstance(tvar_def, ParamSpecDef): + self.fail('The first argument to Callable must be a list of types or "..."', t) + return AnyType(TypeOfAny.from_error) + + # ret = CallableType() + # TODO(shantanu): make + # return ParamSpecType(tvar_def, t.line) return AnyType(TypeOfAny.from_error) else: self.fail('Please use "Callable[[], ]" or "Callable"', t) @@ -791,13 +807,14 @@ def tvar_scope_frame(self) -> Iterator[None]: self.tvar_scope = old_scope def infer_type_variables(self, - type: CallableType) -> List[Tuple[str, TypeVarExpr]]: + type: CallableType) -> List[Tuple[str, TypeVarLikeExpr]]: """Return list of unique type variables referred to in a callable.""" names = [] # type: List[str] - tvars = [] # type: List[TypeVarExpr] + tvars = [] # type: List[TypeVarLikeExpr] for arg in type.arg_types: - for name, tvar_expr in arg.accept(TypeVariableQuery(self.lookup_qualified, - self.tvar_scope)): + for name, tvar_expr in arg.accept( + TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope) + ): if name not in names: names.append(name) tvars.append(tvar_expr) @@ -806,29 +823,30 @@ def infer_type_variables(self, # functions in the return type belong to those functions, not the # function we're currently analyzing. for name, tvar_expr in type.ret_type.accept( - TypeVariableQuery(self.lookup_qualified, self.tvar_scope, - include_callables=False)): + TypeVarLikeQuery(self.lookup_qualified, self.tvar_scope, include_callables=False) + ): if name not in names: names.append(name) tvars.append(tvar_expr) return list(zip(names, tvars)) - def bind_function_type_variables(self, - fun_type: CallableType, defn: Context) -> List[TypeVarDef]: + def bind_function_type_variables( + self, fun_type: CallableType, defn: Context + ) -> Sequence[TypeVarLikeDef]: """Find the type variables of the function type and bind them in our tvar_scope""" if fun_type.variables: for var in fun_type.variables: var_node = self.lookup_qualified(var.name, defn) assert var_node, "Binding for function type variable not found within function" var_expr = var_node.node - assert isinstance(var_expr, TypeVarExpr) + assert isinstance(var_expr, TypeVarLikeExpr) self.tvar_scope.bind_new(var.name, var_expr) return fun_type.variables typevars = self.infer_type_variables(fun_type) # Do not define a new type variable if already defined in scope. typevars = [(name, tvar) for name, tvar in typevars if not self.is_defined_type_var(name, defn)] - defs = [] # type: List[TypeVarDef] + defs = [] # type: List[TypeVarLikeDef] for name, tvar in typevars: if not self.tvar_scope.allow_binding(tvar.fullname): self.fail("Type variable '{}' is bound by an outer class".format(name), defn) @@ -860,17 +878,22 @@ def anal_type(self, t: Type, nested: bool = True) -> Type: if nested: self.nesting_level -= 1 - def anal_var_defs(self, var_defs: List[TypeVarDef]) -> List[TypeVarDef]: - a = [] # type: List[TypeVarDef] - for vd in var_defs: - a.append(TypeVarDef(vd.name, - vd.fullname, - vd.id.raw_id, - self.anal_array(vd.values), - vd.upper_bound.accept(self), - vd.variance, - vd.line)) - return a + def anal_var_def(self, var_def: TypeVarLikeDef) -> TypeVarLikeDef: + if isinstance(var_def, TypeVarDef): + return TypeVarDef( + var_def.name, + var_def.fullname, + var_def.id.raw_id, + self.anal_array(var_def.values), + var_def.upper_bound.accept(self), + var_def.variance, + var_def.line + ) + else: + return var_def + + def anal_var_defs(self, var_defs: Sequence[TypeVarLikeDef]) -> List[TypeVarLikeDef]: + return [self.anal_var_def(vd) for vd in var_defs] def named_type_with_normalized_str(self, fully_qualified_name: str) -> Instance: """Does almost the same thing as `named_type`, except that we immediately @@ -898,7 +921,7 @@ def tuple_type(self, items: List[Type]) -> TupleType: return TupleType(items, fallback=self.named_type('builtins.tuple', [any_type])) -TypeVarList = List[Tuple[str, TypeVarExpr]] +TypeVarLikeList = List[Tuple[str, TypeVarLikeExpr]] # Mypyc doesn't support callback protocols yet. MsgCallback = Callable[[str, Context, DefaultNamedArg(Optional[ErrorCode], 'code')], None] @@ -1059,11 +1082,11 @@ def flatten_tvars(ll: Iterable[List[T]]) -> List[T]: return remove_dups(chain.from_iterable(ll)) -class TypeVariableQuery(TypeQuery[TypeVarList]): +class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]): def __init__(self, lookup: Callable[[str, Context], Optional[SymbolTableNode]], - scope: 'TypeVarScope', + scope: 'TypeVarLikeScope', *, include_callables: bool = True, include_bound_tvars: bool = False) -> None: @@ -1080,12 +1103,12 @@ def _seems_like_callable(self, type: UnboundType) -> bool: return True return False - def visit_unbound_type(self, t: UnboundType) -> TypeVarList: + def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList: name = t.name node = self.lookup(name, t) - if node and isinstance(node.node, TypeVarExpr) and ( + if node and isinstance(node.node, TypeVarLikeExpr) and ( self.include_bound_tvars or self.scope.get_binding(node) is None): - assert isinstance(node.node, TypeVarExpr) + assert isinstance(node.node, TypeVarLikeExpr) return [(name, node.node)] elif not self.include_callables and self._seems_like_callable(t): return [] @@ -1094,7 +1117,7 @@ def visit_unbound_type(self, t: UnboundType) -> TypeVarList: else: return super().visit_unbound_type(t) - def visit_callable_type(self, t: CallableType) -> TypeVarList: + def visit_callable_type(self, t: CallableType) -> TypeVarLikeList: if self.include_callables: return super().visit_callable_type(t) else: diff --git a/mypy/typeops.py b/mypy/typeops.py index a31c07ae74a2..6824e7d45e7e 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -10,7 +10,7 @@ import sys from mypy.types import ( - TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, Overloaded, + TupleType, Instance, FunctionLike, Type, CallableType, TypeVarDef, TypeVarLikeDef, Overloaded, TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, TypedDictType, AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, copy_type, TypeAliasType, TypeQuery @@ -113,7 +113,7 @@ def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance, special_sig: Optional[str], is_new: bool, orig_self_type: Optional[Type] = None) -> CallableType: """Create a type object type based on the signature of __init__.""" - variables = [] # type: List[TypeVarDef] + variables = [] # type: List[TypeVarLikeDef] variables.extend(info.defn.type_vars) variables.extend(init_type.variables) @@ -227,6 +227,8 @@ class B(A): pass # TODO: infer bounds on the type of *args? return cast(F, func) self_param_type = get_proper_type(func.arg_types[0]) + + variables = [] # type: Sequence[TypeVarLikeDef] if func.variables and supported_self_type(self_param_type): if original_type is None: # TODO: type check method override (see #7861). diff --git a/mypy/types.py b/mypy/types.py index 98943e374e48..1fdb45643e3c 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -328,12 +328,37 @@ def is_meta_var(self) -> bool: return self.meta_level > 0 -class TypeVarDef(mypy.nodes.Context): - """Definition of a single type variable.""" - +class TypeVarLikeDef(mypy.nodes.Context): name = '' # Name (may be qualified) fullname = '' # Fully qualified name id = None # type: TypeVarId + + def __init__( + self, name: str, fullname: str, id: Union[TypeVarId, int], line: int = -1, column: int = -1 + ) -> None: + super().__init__(line, column) + self.name = name + self.fullname = fullname + if isinstance(id, int): + id = TypeVarId(id) + self.id = id + + def __repr__(self) -> str: + return self.name + + def new_unification_variable(self) -> 'TypeVarLikeDef': + raise NotImplementedError + + def serialize(self) -> JsonDict: + raise NotImplementedError + + @classmethod + def deserialize(cls, data: JsonDict) -> 'TypeVarLikeDef': + raise NotImplementedError + + +class TypeVarDef(TypeVarLikeDef): + """Definition of a single type variable.""" values = None # type: List[Type] # Value restriction, empty list if no restriction upper_bound = None # type: Type variance = INVARIANT # type: int @@ -341,22 +366,16 @@ class TypeVarDef(mypy.nodes.Context): def __init__(self, name: str, fullname: str, id: Union[TypeVarId, int], values: List[Type], upper_bound: Type, variance: int = INVARIANT, line: int = -1, column: int = -1) -> None: - super().__init__(line, column) + super().__init__(name, fullname, id, line, column) assert values is not None, "No restrictions must be represented by empty list" - self.name = name - self.fullname = fullname - if isinstance(id, int): - id = TypeVarId(id) - self.id = id self.values = values self.upper_bound = upper_bound self.variance = variance - @staticmethod - def new_unification_variable(old: 'TypeVarDef') -> 'TypeVarDef': + def new_unification_variable(self) -> 'TypeVarDef': new_id = TypeVarId.new(meta_level=1) - return TypeVarDef(old.name, old.fullname, new_id, old.values, - old.upper_bound, old.variance, old.line, old.column) + return TypeVarDef(self.name, self.fullname, new_id, self.values, + self.upper_bound, self.variance, self.line, self.column) def __repr__(self) -> str: if self.values: @@ -389,6 +408,32 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarDef': ) +class ParamSpecDef(TypeVarLikeDef): + """Definition of a single ParamSpec variable.""" + + def new_unification_variable(self) -> 'ParamSpecDef': + new_id = TypeVarId.new(meta_level=1) + return ParamSpecDef(self.name, self.fullname, new_id, line=self.line, column=self.column) + + def serialize(self) -> JsonDict: + assert not self.id.is_meta_var() + return { + '.class': 'ParamSpecDef', + 'name': self.name, + 'fullname': self.fullname, + 'id': self.id.raw_id, + } + + @classmethod + def deserialize(cls, data: JsonDict) -> 'ParamSpecDef': + assert data['.class'] == 'ParamSpecDef' + return ParamSpecDef( + data['name'], + data['fullname'], + data['id'], + ) + + class UnboundType(ProperType): """Instance type that has not been bound during semantic analysis.""" @@ -906,6 +951,44 @@ def deserialize(cls, data: JsonDict) -> 'TypeVarType': return TypeVarType(tvdef) +class ParamSpecType(ProperType): + """A parameter specification type.""" + + __slots__ = ('name', 'fullname', 'id') + + def __init__(self, binder: ParamSpecDef, line: int = -1, column: int = -1) -> None: + super().__init__(line, column) + self.name = binder.name # Name of the ParamSpec (for messages and debugging) + self.fullname = binder.fullname # type: str + self.id = binder.id # type: TypeVarId + + def accept(self, visitor: 'TypeVisitor[T]') -> T: + return visitor.visit_param_spec(self) + + def __hash__(self) -> int: + return hash(self.id) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ParamSpecType): + return NotImplemented + return self.id == other.id + + def serialize(self) -> JsonDict: + assert not self.id.is_meta_var() + return { + '.class': 'ParamSpecType', + 'name': self.name, + 'fullname': self.fullname, + 'id': self.id.raw_id, + } + + @classmethod + def deserialize(cls, data: JsonDict) -> 'ParamSpecType': + assert data['.class'] == 'ParamSpecType' + tvdef = ParamSpecDef(data['name'], data['fullname'], data['id']) + return ParamSpecType(tvdef) + + class FunctionLike(ProperType): """Abstract base class for function types.""" @@ -976,7 +1059,7 @@ def __init__(self, fallback: Instance, name: Optional[str] = None, definition: Optional[SymbolNode] = None, - variables: Optional[List[TypeVarDef]] = None, + variables: Optional[Sequence[TypeVarLikeDef]] = None, line: int = -1, column: int = -1, is_ellipsis_args: bool = False, @@ -1028,7 +1111,7 @@ def copy_modified(self, fallback: Bogus[Instance] = _dummy, name: Bogus[Optional[str]] = _dummy, definition: Bogus[SymbolNode] = _dummy, - variables: Bogus[List[TypeVarDef]] = _dummy, + variables: Bogus[Sequence[TypeVarLikeDef]] = _dummy, line: Bogus[int] = _dummy, column: Bogus[int] = _dummy, is_ellipsis_args: Bogus[bool] = _dummy, @@ -2030,6 +2113,15 @@ def visit_type_var(self, t: TypeVarType) -> str: s += '(upper_bound={})'.format(t.upper_bound.accept(self)) return s + def visit_param_spec(self, t: ParamSpecType) -> str: + if t.name is None: + # Anonymous param spec type (only numeric id). + s = '`{}'.format(t.id) + else: + # Named param spec type. + s = '{}`{}'.format(t.name, t.id) + return s + def visit_callable_type(self, t: CallableType) -> str: s = '' bare_asterisk = False @@ -2057,15 +2149,19 @@ def visit_callable_type(self, t: CallableType) -> str: if t.variables: vs = [] - # We reimplement TypeVarDef.__repr__ here in order to support id_mapper. for var in t.variables: - if var.values: - vals = '({})'.format(', '.join(val.accept(self) for val in var.values)) - vs.append('{} in {}'.format(var.name, vals)) - elif not is_named_instance(var.upper_bound, 'builtins.object'): - vs.append('{} <: {}'.format(var.name, var.upper_bound.accept(self))) + if isinstance(var, TypeVarDef): + # We reimplement TypeVarDef.__repr__ here in order to support id_mapper. + if var.values: + vals = '({})'.format(', '.join(val.accept(self) for val in var.values)) + vs.append('{} in {}'.format(var.name, vals)) + elif not is_named_instance(var.upper_bound, 'builtins.object'): + vs.append('{} <: {}'.format(var.name, var.upper_bound.accept(self))) + else: + vs.append(var.name) else: - vs.append(var.name) + # For other TypeVarLikeDefs, just use the repr + vs.append(repr(var)) s = '{} {}'.format('[{}]'.format(', '.join(vs)), s) return 'def {}'.format(s) diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index 8d7459f7a551..a91309122870 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -6,7 +6,7 @@ Type, SyntheticTypeVisitor, AnyType, UninhabitedType, NoneType, ErasedType, DeletedType, TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType, Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType, - PlaceholderType, PartialType, RawExpressionType, TypeAliasType + PlaceholderType, PartialType, RawExpressionType, TypeAliasType, ParamSpecType ) @@ -37,6 +37,9 @@ def visit_type_var(self, t: TypeVarType) -> None: # definition. We want to traverse everything just once. pass + def visit_param_spec(self, t: ParamSpecType) -> None: + pass + def visit_literal_type(self, t: LiteralType) -> None: t.fallback.accept(self) diff --git a/mypy/visitor.py b/mypy/visitor.py index d692142e6bcc..0093c02ca44c 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -155,6 +155,10 @@ def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> T: def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> T: pass + @abstractmethod + def visit_paramspec_var_expr(self, o: 'mypy.nodes.ParamSpecExpr') -> T: + pass + @abstractmethod def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: pass @@ -529,6 +533,9 @@ def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> T: def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> T: pass + def visit_paramspec_var_expr(self, o: 'mypy.nodes.ParamSpecExpr') -> T: + pass + def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: pass diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test new file mode 100644 index 000000000000..dc9d2b7bc24a --- /dev/null +++ b/test-data/unit/check-parameter-specification.test @@ -0,0 +1,150 @@ +[case testBasicParamSpec] +from mypy_extensions import ParamSpec +P = ParamSpec('P') +[builtins fixtures/tuple.pyi] + +[case testParamSpecLocations] +from typing import Callable, List +from mypy_extensions import ParamSpec, Concatenate +P = ParamSpec('P') + +x: P # E: ParamSpec "P" is unbound +def foo1(x: Callable[P, int]) -> Callable[P, str]: ... +def foo2(x: P) -> P: ... # E: Invalid location for ParamSpec "P" +# TODO(shantanu): uncomment +# def foo3(x: Concatenate[int, P]) -> int: ... $ E: Invalid location for Concatenate +def foo4(x: List[P]) -> None: ... # E: Invalid location for ParamSpec "P" +def foo5(x: Callable[[int, str], P]) -> None: ... # E: Invalid location for ParamSpec "P" +def foo6(x: Callable[[P], int]) -> None: ... # E: Invalid location for ParamSpec "P" +[builtins fixtures/tuple.pyi] + +[case testParamSpecClasses] +from typing import Callable, Generic, TypeVar +from mypy_extensions import ParamSpec, Concatenate +P = ParamSpec('P') + +T = TypeVar("T") +S = TypeVar("S") +P_2 = ParamSpec("P_2") + +class X(Generic[T, P]): + f: Callable[P, int] + x: T + +def foo1(x: X[int, P_2]) -> str: ... # Accepted +def foo2(x: X[int, Concatenate[int, P_2]]) -> str: ... # Accepted +def foo3(x: X[int, [int, bool]]) -> str: ... # Accepted +def foo4(x: X[int, ...]) -> str: ... # Accepted +def foo5(x: X[int, int]) -> str: ... # Rejected + +class Z(Generic[P]): + f: Callable[P, int] + +def foo6(x: Z[[int, str, bool]]) -> str: ... # Accepted +def foo7(x: Z[int, str, bool]) -> str: ... # Accepted +[builtins fixtures/tuple.pyi] + +[case testParamSpecSemantics] +from typing import Callable +from mypy_extensions import ParamSpec +P = ParamSpec('P') + +def changes_return_type_to_str(x: Callable[P, int]) -> Callable[P, str]: ... +def returns_int(a: str, b: bool) -> int: ... + +f = changes_return_type_to_str(returns_int) +reveal_type(f) # f should have the type:(a: str, b: bool) -> str + +f("A", True) # Accepted +f(a="A", b=True) # Accepted +f("A", "A") # Rejected + +def expects_str(x: str): ... +def expects_int(x: int): ... + +expects_str(f("A", True)) # Accepted +expects_int(f("A", True)) # Rejected +[builtins fixtures/tuple.pyi] + +[case testParamSpecSemanticsMore] +from typing import Callable +from mypy_extensions import ParamSpec +P = ParamSpec("P") + +def foo(x: Callable[P, int], y: Callable[P, int]) -> Callable[P, bool]: ... + +def x_int_y_str(x: int, y: str) -> int: ... +def y_int_x_str(y: int, x: str) -> int: ... + +f1 = foo(x_int_y_str, x_int_y_str) # Should return (x: int, y: str) -> bool +reveal_type(f1) + +f2 = foo(x_int_y_str, y_int_x_str) # Could return (__a: int, __b: str) -> int +reveal_type(f2) + +def keyword_only_x(*, x: int) -> int: ... +def keyword_only_y(*, y: int) -> int: ... +foo(keyword_only_x, keyword_only_y) # Rejected +[builtins fixtures/tuple.pyi] + + +[case testParamSpecConstructor] +from typing import Callable, TypeVar +from mypy_extensions import ParamSpec +P = ParamSpec("P") +U = TypeVar("U") + +class Y(Generic[U, P]): + f: Callable[P, str] + prop: U + + def __init__(self, f: Callable[P, str], prop: U) -> None: + self.f = f + self.prop = prop + +def a(q: int) -> str: ... + +reveal_type(Y(a, 1)) # Should resolve to Y[(q: int), int] +reveal_type(Y(a, 1).f) # Should resolve to (q: int) -> str + +[builtins fixtures/tuple.pyi] + + +[case testConcatenate] +from typing import Callable, TypeVar +from mypy_extensions import ParamSpec, Concatenate +P = ParamSpec("P") +U = TypeVar("T") + +def bar(x: int, *args: bool) -> int: ... + +def add(x: Callable[P, int]) -> Callable[Concatenate[str, P], bool]: ... +reveal_type(add(bar)) # Should return (__a: str, x: int, *args: bool) -> bool + +def remove(x: Callable[Concatenate[int, P], int]) -> Callable[P, bool]: ... +reveal_type(remove(bar)) # Should return (*args: bool) -> bool + +def transform( + x: Callable[Concatenate[int, P], int] +) -> Callable[Concatenate[str, P], bool]: ... + +reveal_type(transform(bar)) # Should return (__a: str, *args: bool) -> bool + +def expects_int_first(x: Callable[Concatenate[int, P], int]) -> None: ... + +@expects_int_first # Rejected +def one(x: str) -> int: ... + +@expects_int_first # Rejected +def two(*, x: int) -> int: ... + +@expects_int_first # Rejected +def three(**kwargs: int) -> int: ... + +@expects_int_first # Accepted +def four(*args: int) -> int: ... + +[builtins fixtures/tuple.pyi] + + +# TODO: tests from https://www.python.org/dev/peps/pep-0612/#id15 "The components of a ParamSpec" and on diff --git a/test-data/unit/lib-stub/mypy_extensions.pyi b/test-data/unit/lib-stub/mypy_extensions.pyi index 306d217f478e..f8f297c8ba1e 100644 --- a/test-data/unit/lib-stub/mypy_extensions.pyi +++ b/test-data/unit/lib-stub/mypy_extensions.pyi @@ -48,3 +48,7 @@ def trait(cls: Any) -> Any: ... mypyc_attr: Any class FlexibleAlias(Generic[_T, _U]): ... + +# Special cased in the type checker, so can be initialised to anything +ParamSpec = 0 +Concatenate = 0 diff --git a/test-data/unit/semanal-errors.test b/test-data/unit/semanal-errors.test index afd39122f99e..098fb065cb48 100644 --- a/test-data/unit/semanal-errors.test +++ b/test-data/unit/semanal-errors.test @@ -1419,3 +1419,12 @@ def g() -> None: # N: (Hint: Use "T" in function signature to bind "T" inside a function) [builtins fixtures/dict.pyi] [out] + +[case testParamSpec] +from mypy_extensions import ParamSpec + +TParams = ParamSpec('TParams') +TP = ParamSpec('?') # E: String argument 1 '?' to ParamSpec(...) does not match variable name 'TP' +TP2: int = ParamSpec('TP2') # E: Cannot declare the type of a parameter specification + +[out]