Skip to content

Commit be364ac

Browse files
smessmerfacebook-github-bot
authored andcommitted
Specify overload name in function schema (pytorch#18037)
Summary: Pull Request resolved: pytorch#18037 The FunctionSchema can now store an overload name and the parser knows how to parse it. Specify like this: my_func.overload1(arg1: Tensor) -> Tensor my_func.overload2(arg1: Tensor, arg2: Tensor) -> Tensor Reviewed By: zdevito Differential Revision: D14467497 fbshipit-source-id: 8832b32f07351bb61090357b17b77a6a2fed3650
1 parent 7a3488e commit be364ac

29 files changed

+260
-224
lines changed

aten/src/ATen/core/function_schema.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,32 +67,37 @@ struct Argument {
6767
struct FunctionSchema {
6868
FunctionSchema(
6969
std::string name,
70+
std::string overload_name,
7071
std::vector<Argument> arguments,
7172
std::vector<Argument> returns,
7273
bool is_vararg = false,
7374
bool is_varret = false)
7475
: name_(std::move(name)),
76+
overload_name_(std::move(overload_name)),
7577
arguments_(std::move(arguments)),
7678
returns_(std::move(returns)),
7779
is_vararg_(is_vararg),
7880
is_varret_(is_varret) {}
7981

8082
FunctionSchema(
8183
Symbol name,
84+
std::string overload_name,
8285
std::vector<Argument> arguments,
8386
std::vector<Argument> returns,
8487
bool is_vararg = false,
8588
bool is_varret = false,
8689
std::vector<std::string> writes = {})
8790
: FunctionSchema(
8891
name.toQualString(),
92+
std::move(overload_name),
8993
std::move(std::move(arguments)),
9094
std::move(std::move(returns)),
9195
is_vararg,
9296
is_varret) {}
9397

9498
private:
9599
const std::string name_;
100+
const std::string overload_name_;
96101
const std::vector<Argument> arguments_;
97102
const std::vector<Argument> returns_;
98103
// if true then this schema takes an arbitrary number of additional arguments
@@ -106,6 +111,9 @@ struct FunctionSchema {
106111
const std::string& name() const {
107112
return name_;
108113
}
114+
const std::string& overload_name() const {
115+
return overload_name_;
116+
}
109117
const std::vector<Argument>& arguments() const {
110118
return arguments_;
111119
}

caffe2/core/c10_operator.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,10 @@ inline c10::FunctionSchema make_function_schema_for_c10(const char* OperatorName
9999
IValue());
100100

101101
return c10::FunctionSchema(
102-
std::string("_caffe2::") + OperatorName,
103-
std::move(actual_inputs), std::move(outputs));
102+
std::string("_caffe2::") + OperatorName,
103+
"",
104+
std::move(actual_inputs),
105+
std::move(outputs));
104106
}
105107

106108
}

caffe2/operators/experimental/c10/schemas/add.cc

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,18 @@ using caffe2::CPUContext;
66
namespace caffe2 {
77
namespace ops {
88
// TODO Parse schema string instead of creating FunctionSchema manually
9-
C10_DEFINE_OP_SCHEMA(Add, FunctionSchema(
10-
"_c10_experimental::Add",
11-
(std::vector<c10::Argument>{
12-
c10::Argument("input1"),
13-
c10::Argument("input2"),
14-
c10::Argument("output"),
15-
c10::Argument("legacy_broadcast", BoolType::get()),
16-
c10::Argument("axis", IntType::get())
17-
}), (std::vector<c10::Argument>{
18-
})
19-
));
9+
C10_DEFINE_OP_SCHEMA(
10+
Add,
11+
FunctionSchema(
12+
"_c10_experimental::Add",
13+
"",
14+
(std::vector<c10::Argument>{
15+
c10::Argument("input1"),
16+
c10::Argument("input2"),
17+
c10::Argument("output"),
18+
c10::Argument("legacy_broadcast", BoolType::get()),
19+
c10::Argument("axis", IntType::get())}),
20+
(std::vector<c10::Argument>{})));
2021
}
2122
}
2223

caffe2/operators/experimental/c10/schemas/averaged_loss.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ using caffe2::CPUContext;
77
namespace caffe2 {
88
namespace ops {
99
// TODO Parse schema string instead of creating FunctionSchema manually
10-
C10_DEFINE_OP_SCHEMA(AveragedLoss, FunctionSchema(
11-
"_c10_experimental::AveragedLoss",
12-
(std::vector<c10::Argument>{
13-
c10::Argument("input"),
14-
c10::Argument("output")
15-
}), (std::vector<c10::Argument>{
16-
})
17-
));
10+
C10_DEFINE_OP_SCHEMA(
11+
AveragedLoss,
12+
FunctionSchema(
13+
"_c10_experimental::AveragedLoss",
14+
"",
15+
(std::vector<c10::Argument>{c10::Argument("input"),
16+
c10::Argument("output")}),
17+
(std::vector<c10::Argument>{})));
1818
}
1919
}
2020

caffe2/operators/experimental/c10/schemas/batch_gather.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ using caffe2::CPUContext;
77
namespace caffe2 {
88
namespace ops {
99
// TODO Parse schema string instead of creating FunctionSchema manually
10-
C10_DEFINE_OP_SCHEMA(BatchGather, FunctionSchema(
11-
"_c10_experimental::BatchGather",
12-
(std::vector<c10::Argument>{
13-
c10::Argument("data"),
14-
c10::Argument("indices"),
15-
c10::Argument("output")
16-
}), (std::vector<c10::Argument>{
17-
})
18-
));
10+
C10_DEFINE_OP_SCHEMA(
11+
BatchGather,
12+
FunctionSchema(
13+
"_c10_experimental::BatchGather",
14+
"",
15+
(std::vector<c10::Argument>{c10::Argument("data"),
16+
c10::Argument("indices"),
17+
c10::Argument("output")}),
18+
(std::vector<c10::Argument>{})));
1919
}
2020
}
2121

caffe2/operators/experimental/c10/schemas/batch_matmul.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@ using caffe2::CPUContext;
77
namespace caffe2 {
88
namespace ops {
99
// TODO Parse schema string instead of creating FunctionSchema manually
10-
C10_DEFINE_OP_SCHEMA(BatchMatmul, FunctionSchema(
11-
"_c10_experimental::BatchMatmul",
12-
(std::vector<c10::Argument>{
13-
c10::Argument("A"),
14-
c10::Argument("B"),
15-
c10::Argument("output"),
16-
c10::Argument("trans_a", IntType::get()),
17-
c10::Argument("trans_b", IntType::get()),
18-
c10::Argument("broadcast", IntType::get())
19-
}), (std::vector<c10::Argument>{
20-
})
21-
));
10+
C10_DEFINE_OP_SCHEMA(
11+
BatchMatmul,
12+
FunctionSchema(
13+
"_c10_experimental::BatchMatmul",
14+
"",
15+
(std::vector<c10::Argument>{
16+
c10::Argument("A"),
17+
c10::Argument("B"),
18+
c10::Argument("output"),
19+
c10::Argument("trans_a", IntType::get()),
20+
c10::Argument("trans_b", IntType::get()),
21+
c10::Argument("broadcast", IntType::get())}),
22+
(std::vector<c10::Argument>{})));
2223
}
2324
}
2425

caffe2/operators/experimental/c10/schemas/cast.cc

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@ using caffe2::CPUContext;
88
namespace caffe2 {
99
namespace ops {
1010
// TODO Parse schema string instead of creating FunctionSchema manually
11-
C10_DEFINE_OP_SCHEMA(Cast, FunctionSchema(
12-
"_c10_experimental::Cast",
13-
(std::vector<c10::Argument>{
14-
c10::Argument("input"),
15-
c10::Argument("output"),
16-
c10::Argument("to_dtype", IntType::get()),
17-
}), (std::vector<c10::Argument>{
18-
})
19-
));
11+
C10_DEFINE_OP_SCHEMA(
12+
Cast,
13+
FunctionSchema(
14+
"_c10_experimental::Cast",
15+
"",
16+
(std::vector<c10::Argument>{
17+
c10::Argument("input"),
18+
c10::Argument("output"),
19+
c10::Argument("to_dtype", IntType::get()),
20+
}),
21+
(std::vector<c10::Argument>{})));
2022
}
2123
}
2224

caffe2/operators/experimental/c10/schemas/concat.cc

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@ using caffe2::CPUContext;
77
namespace caffe2 {
88
namespace ops {
99
// TODO Parse schema string instead of creating FunctionSchema manually
10-
C10_DEFINE_OP_SCHEMA(Concat, FunctionSchema(
11-
"_c10_experimental::Concat",
12-
(std::vector<c10::Argument>{
13-
c10::Argument("inputs", ListType::ofTensors()),
14-
c10::Argument("output"),
15-
c10::Argument("split_info", FloatType::get()),
16-
c10::Argument("add", IntType::get()),
17-
c10::Argument("add_axis", IntType::get())
18-
}), (std::vector<c10::Argument>{
19-
})
20-
));
10+
C10_DEFINE_OP_SCHEMA(
11+
Concat,
12+
FunctionSchema(
13+
"_c10_experimental::Concat",
14+
"",
15+
(std::vector<c10::Argument>{
16+
c10::Argument("inputs", ListType::ofTensors()),
17+
c10::Argument("output"),
18+
c10::Argument("split_info", FloatType::get()),
19+
c10::Argument("add", IntType::get()),
20+
c10::Argument("add_axis", IntType::get())}),
21+
(std::vector<c10::Argument>{})));
2122
}
2223
}
2324

caffe2/operators/experimental/c10/schemas/enforce_finite.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ using caffe2::CPUContext;
77
namespace caffe2 {
88
namespace ops {
99
// TODO Parse schema string instead of creating FunctionSchema manually
10-
C10_DEFINE_OP_SCHEMA(EnforceFinite, FunctionSchema(
11-
"_c10_experimental::EnforceFinite",
12-
(std::vector<c10::Argument>{
13-
c10::Argument("input")
14-
}), (std::vector<c10::Argument>{
15-
})
16-
));
10+
C10_DEFINE_OP_SCHEMA(
11+
EnforceFinite,
12+
FunctionSchema(
13+
"_c10_experimental::EnforceFinite",
14+
"",
15+
(std::vector<c10::Argument>{c10::Argument("input")}),
16+
(std::vector<c10::Argument>{})));
1717
}
1818
}
1919

caffe2/operators/experimental/c10/schemas/expand_dims.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ using c10::ivalue::IntList;
99
namespace caffe2 {
1010
namespace ops {
1111
// TODO Parse schema string instead of creating FunctionSchema manually
12-
C10_DEFINE_OP_SCHEMA(ExpandDims, FunctionSchema(
13-
"_c10_experimental::ExpandDims",
14-
(std::vector<c10::Argument>{
15-
c10::Argument("input"),
16-
c10::Argument("output"),
17-
c10::Argument("dims", ListType::ofInts())
18-
}), (std::vector<c10::Argument>{
19-
})
20-
));
12+
C10_DEFINE_OP_SCHEMA(
13+
ExpandDims,
14+
FunctionSchema(
15+
"_c10_experimental::ExpandDims",
16+
"",
17+
(std::vector<c10::Argument>{c10::Argument("input"),
18+
c10::Argument("output"),
19+
c10::Argument("dims", ListType::ofInts())}),
20+
(std::vector<c10::Argument>{})));
2121
}
2222
}
2323

0 commit comments

Comments
 (0)