Skip to content

Commit 91d16cb

Browse files
Hao Lufacebook-github-bot
Hao Lu
authored andcommitted
[Jit] Fix schema of aten::split int[] version (pytorch#69745)
Summary: Pull Request resolved: pytorch#69745 Missed in D31935573 (pytorch@6b44e75). Reviewed By: d1jang Differential Revision: D31889867 fbshipit-source-id: 417bd0b15db4891dbd641b35a803553f11d0d756
1 parent 9962bfb commit 91d16cb

File tree

3 files changed

+57
-2
lines changed

3 files changed

+57
-2
lines changed

test/backward_compatibility/check_backward_compatibility.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
("aten::grid_sampler_2d_backward", datetime.date(2021, 10, 21)),
6868
("prim::TensorExprDynamicGuard", datetime.date(2021, 11, 20)),
6969
("aten::split_with_sizes", datetime.date(2021, 11, 20)),
70-
("aten::split", datetime.date(2021, 11, 20)),
70+
("aten::split", datetime.date(2021, 12, 20)),
7171
("aten::vsplit", datetime.date(2021, 11, 20)),
7272
("aten::tensor_split", datetime.date(2021, 11, 20)),
7373
("aten::chunk", datetime.date(2021, 11, 20)),

test/cpp/jit/test_alias_analysis.cpp

+55
Original file line numberDiff line numberDiff line change
@@ -1532,6 +1532,61 @@ TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
15321532
EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["z"]));
15331533
}
15341534

1535+
TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) {
1536+
auto graph = std::make_shared<Graph>();
1537+
std::unordered_map<std::string, Value*> vmap;
1538+
auto graph_string = R"IR(
1539+
graph():
1540+
%x : Tensor = prim::MakeTestTensor()
1541+
%0 : int = prim::Constant[value=0]()
1542+
%1 : int = prim::Constant[value=1]()
1543+
%2 : int = prim::Constant[value=2]()
1544+
%y : Tensor = aten::add(%x, %x, %0)
1545+
%lengths_list : int[] = prim::tolist(%1, %2)
1546+
%a : Tensor[] = aten::split(%y, %lengths_list, %0)
1547+
%b : Tensor, %c : Tensor = prim::ListUnpack(%a)
1548+
%b1 : Tensor = aten::flatten(%b, %0, %1)
1549+
%c1 : Tensor = aten::flatten(%c, %0, %1)
1550+
%d : Tensor = aten::add(%b1, %c1, %0)
1551+
return (%d))IR";
1552+
1553+
torch::jit::parseIR(graph_string, graph.get(), vmap);
1554+
AliasDb aliasDb(
1555+
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
1556+
1557+
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
1558+
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
1559+
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b1"]));
1560+
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c1"]));
1561+
}
1562+
1563+
TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) {
1564+
auto graph = std::make_shared<Graph>();
1565+
std::unordered_map<std::string, Value*> vmap;
1566+
auto graph_string = R"IR(
1567+
graph():
1568+
%x : Tensor = prim::MakeTestTensor()
1569+
%0 : int = prim::Constant[value=0]()
1570+
%1 : int = prim::Constant[value=1]()
1571+
%2 : int = prim::Constant[value=2]()
1572+
%y : Tensor = aten::add(%x, %x, %0)
1573+
%a : Tensor[] = aten::split(%y, %2, %0)
1574+
%b : Tensor, %c : Tensor = prim::ListUnpack(%a)
1575+
%b1 : Tensor = aten::flatten(%b, %0, %1)
1576+
%c1 : Tensor = aten::flatten(%c, %0, %1)
1577+
%d : Tensor = aten::add(%b1, %c1, %0)
1578+
return (%d))IR";
1579+
1580+
torch::jit::parseIR(graph_string, graph.get(), vmap);
1581+
AliasDb aliasDb(
1582+
graph, /*isFrozen=*/false, /*enablePreciseTupleContainerAnalysis=*/true);
1583+
1584+
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b"]));
1585+
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c"]));
1586+
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["b1"]));
1587+
EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["c1"]));
1588+
}
1589+
15351590
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
15361591
auto registry = torch::RegisterOperators().op(
15371592
"foo::rand12(Tensor(a) arg1) -> Tensor(b)",

torch/csrc/jit/runtime/register_special_ops.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ void createTensorFromList(Stack& stack) {
245245
RegisterOperators reg({
246246
OperatorGenerator(
247247
TORCH_SELECTIVE_SCHEMA(
248-
"aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]"),
248+
"aten::split(Tensor(a -> *) self, int[] split_sizes, int dim=0) -> Tensor(a)[]"),
249249
[](Stack& stack) {
250250
RECORD_FUNCTION("split_with_sizes", last(stack, 3));
251251

0 commit comments

Comments
 (0)