Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constraints on ListVariadics #318

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions analysis/analysisError.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,16 @@ let rec messages ~concise ~signature location kind =
(Type.Variable expected)
name;
]
| InvalidTypeParameters
{ name; kind = AttributeResolution.ViolateConstraintsVariadic { expected; actual } } ->
[
Format.asprintf
"Type parameter list `%a` violates constraints on `%s` in generic type `%s`."
(Type.Record.OrderedTypes.pp_concise ~pp_type)
actual
(Type.Record.Variable.RecordVariadic.RecordList.name expected)
name;
]
| InvalidTypeParameters { name; kind = AttributeResolution.UnexpectedKind { expected; actual } }
->
let details =
Expand Down Expand Up @@ -3453,6 +3463,9 @@ let dequalify
actual = dequalify actual;
expected = Type.Variable.Unary.dequalify ~dequalify_map expected;
}
| AttributeResolution.ViolateConstraintsVariadic { actual; expected } ->
AttributeResolution.ViolateConstraintsVariadic
{ actual; expected = Type.Variable.Variadic.List.dequalify ~dequalify_map expected }
| AttributeResolution.UnexpectedKind { actual; expected } ->
AttributeResolution.UnexpectedKind
{ actual; expected = Type.Variable.dequalify dequalify_map expected }
Expand Down
31 changes: 27 additions & 4 deletions analysis/attributeResolution.ml
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ module TypeParameterValidationTypes = struct
actual: Type.t;
expected: Type.Variable.Unary.t;
}
| ViolateConstraintsVariadic of {
actual: Type.OrderedTypes.t;
expected: Type.Variable.Variadic.List.t;
}
| UnexpectedKind of {
actual: Type.Parameter.t;
expected: Type.Variable.t;
Expand Down Expand Up @@ -821,7 +825,7 @@ class base class_metadata_environment dependency =
TypeConstraints.empty
~order
~pair
>>| TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair
>>= TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair
|> Option.is_none
in
if invalid then
Expand All @@ -848,10 +852,29 @@ class base class_metadata_environment dependency =
( CallableParameters Undefined,
Some
{ name; kind = UnexpectedKind { expected = generic; actual = given } } )
| ParameterVariadic _, CallableParameters _
| ListVariadic _, Group _ ->
| ParameterVariadic _, CallableParameters _ -> given, None
| ListVariadic generic, Group given ->
(* TODO(T47346673): accept w/ new kind of validation *)
given, None
let invalid =
let order = self#full_order ~assumptions in
let pair = Type.Variable.ListVariadicPair (generic, given) in
TypeOrder.OrderedConstraints.add_lower_bound
TypeConstraints.empty
~order
~pair
>>= TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair
|> Option.is_none
in
if invalid then
( Type.Parameter.Group Any,
Some
{
name;
kind =
ViolateConstraintsVariadic { actual = given; expected = generic };
} )
else
Type.Parameter.Group given, None
in
List.map paired ~f:check_parameter
|> List.unzip
Expand Down
4 changes: 4 additions & 0 deletions analysis/attributeResolution.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ type generic_type_problems =
actual: Type.t;
expected: Type.Variable.Unary.t;
}
| ViolateConstraintsVariadic of {
actual: Type.OrderedTypes.t;
expected: Type.Variable.Variadic.List.t;
}
| UnexpectedKind of {
actual: Type.Parameter.t;
expected: Type.Variable.t;
Expand Down
48 changes: 48 additions & 0 deletions analysis/test/integration/typeVariableTest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1834,6 +1834,53 @@ let test_list_variadics context =
()


let test_list_variadics_constraints context =
let assert_type_errors = assert_type_errors ~context in
assert_type_errors
{|
from typing import Generic, TypeVar
from typing_extensions import Literal
from pyre_extensions import ListVariadic
from pyre_extensions.type_variable_operators import Concatenate as Cat

Ts = ListVariadic("Ts", bound=int)
A = TypeVar("A")
class Vec(Generic[Ts]): ...

def f1(x : Vec[Cat[Ts,A]]) -> None: ...
def f2( *ts: Ts) -> Vec[Cat[Ts,float]]: ...
def f3( *ts: Ts) -> Vec[[int,float]]: ...
|}
[
"Invalid type parameters [24]: Type parameter list `Concatenate[test.Ts, Variable[test.A]]` \
violates constraints on `Ts` in generic type `Vec`.";
"Invalid type parameters [24]: Type parameter list `Concatenate[test.Ts, float]` violates \
constraints on `Ts` in generic type `Vec`.";
"Invalid type parameters [24]: Type parameter list `int, float` violates constraints on `Ts` \
in generic type `Vec`.";
];
assert_type_errors
{|
from typing import Generic, TypeVar
from pyre_extensions import ListVariadic

Ts1 = ListVariadic("Ts1", bound=int)
Ts2 = ListVariadic("Ts2", bound=float)

class Vec1(Generic[Ts1]): ...
class Vec2(Generic[Ts2]): ...

def f1(x: Vec1[Ts1]) -> None: ...
def f2(x: Vec2[Ts2]) -> None: ...
def g1(x: Vec1[Ts2]) -> None: ...
def g2(x: Vec2[Ts1]) -> None: ...
|}
[
"Invalid type parameters [24]: Type parameter list `test.Ts2` violates constraints on `Ts1` \
in generic type `Vec1`.";
]


let test_map context =
let assert_type_errors = assert_type_errors ~context in
assert_type_errors
Expand Down Expand Up @@ -2424,6 +2471,7 @@ let () =
"single_explicit_error" >:: test_single_explicit_error;
"callable_parameter_variadics" >:: test_callable_parameter_variadics;
"list_variadics" >:: test_list_variadics;
"list_variadics_constraints" >:: test_list_variadics_constraints;
"map" >:: test_map;
"user_defined_variadics" >:: test_user_defined_variadics;
"concatenation" >:: test_concatenation_operator;
Expand Down
11 changes: 10 additions & 1 deletion analysis/test/typeTest.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2082,7 +2082,16 @@ let test_parse_type_variable_declarations _ =
assert_parses_declaration
"pyre_extensions.ListVariadic('Ts')"
(Type.Variable.ListVariadic (Type.Variable.Variadic.List.create "target"));
assert_declaration_does_not_parse "pyre_extensions.ListVariadic('Ts', int, str)";
assert_parses_declaration
"pyre_extensions.ListVariadic('Ts', int, str)"
(Type.Variable.ListVariadic
(Type.Variable.Variadic.List.create
"target"
~constraints:(Explicit [Type.Primitive "int"; Type.Primitive "str"])));
assert_parses_declaration
"pyre_extensions.ListVariadic('Ts', bound=int)"
(Type.Variable.ListVariadic
(Type.Variable.Variadic.List.create "target" ~constraints:(Bound (Type.Primitive "int"))));
()


Expand Down
72 changes: 45 additions & 27 deletions analysis/type.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2741,12 +2741,10 @@ let rec create_logic ~aliases ~variable_aliases { Node.value = expression; _ } =
in
List.find_map ~f:bound arguments
in
if not (List.is_empty explicits) then
Record.Variable.Explicit explicits
else if Option.is_some bound then
Bound (Option.value_exn bound)
else
Unconstrained
match explicits, bound with
| [], Some bound -> Record.Variable.Bound bound
| explicits, _ when List.length explicits > 0 -> Explicit explicits
| _ -> Unconstrained
in
let variance =
let variance_definition = function
Expand Down Expand Up @@ -4101,28 +4099,48 @@ end = struct


let parse_declaration value ~target =
match value with
| {
Node.value =
Expression.Call
{
callee =
{
Node.value =
Name
(Name.Attribute
{
base = { Node.value = Name (Name.Identifier "pyre_extensions"); _ };
attribute = "ListVariadic";
special = false;
});
_;
};
arguments = [{ Call.Argument.value = { Node.value = String _; _ }; _ }];
};
_;
} ->
match Node.value value with
| Expression.Call { callee; arguments = [{ value = { Node.value = String _; _ }; _ }] }
when name_is ~name:"pyre_extensions.ListVariadic" callee ->
Some (create (Reference.show target))
| Call
{
callee;
arguments = { Call.Argument.value = { Node.value = String _; _ }; _ } :: arguments;
}
when name_is ~name:"pyre_extensions.ListVariadic" callee ->
let constraints =
let explicits =
let explicit = function
| {
Call.Argument.name = None;
value = { Node.value = Name (Name.Identifier identifier); _ };
} ->
let identifier = Identifier.sanitized identifier in
Some (Primitive identifier)
| _ -> None
in
List.filter_map ~f:explicit arguments
in
let bound =
let bound = function
| {
Call.Argument.value = { Node.value = Name (Name.Identifier identifier); _ };
name = Some { Node.value = bound; _ };
}
when String.equal (Identifier.sanitized bound) "bound" ->
let identifier = Identifier.sanitized identifier in
Some (Primitive identifier)
| _ -> None
in
List.find_map ~f:bound arguments
in
match explicits, bound with
| [], Some bound -> Record.Variable.Bound bound
| explicits, _ when List.length explicits > 0 -> Explicit explicits
| _ -> Unconstrained
in
Some (create (Reference.show target) ~constraints)
| _ -> None
end
end
Expand Down
29 changes: 26 additions & 3 deletions analysis/type.mli
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,16 @@ module Record : sig
end

module RecordList : sig
type 'annotation record [@@deriving compare, eq, sexp, show, hash]
type 'annotation record = {
name: Identifier.t;
constraints: 'annotation constraints;
variance: variance;
state: state;
namespace: RecordNamespace.t;
}
[@@deriving compare, eq, sexp, show, hash]

val name : 'a record -> string
end
end

Expand All @@ -64,10 +73,24 @@ module Record : sig
module OrderedTypes : sig
module RecordConcatenate : sig
module Middle : sig
type 'annotation t [@@deriving compare, eq, sexp, show, hash]
type 'annotation t = {
variable: 'annotation Variable.RecordVariadic.RecordList.record;
mappers: Identifier.t list;
}
[@@deriving compare, eq, sexp, show, hash]
end

type ('middle, 'outer) t [@@deriving compare, eq, sexp, show, hash]
type 'annotation wrapping = {
head: 'annotation list;
tail: 'annotation list;
}
[@@deriving compare, eq, sexp, show, hash]

type ('middle, 'annotation) t = {
middle: 'middle;
wrapping: 'annotation wrapping;
}
[@@deriving compare, eq, sexp, show, hash]
end

type 'annotation record =
Expand Down
Loading