@@ -1532,6 +1532,61 @@ TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
1532
1532
EXPECT_TRUE (aliasDb.mayContainAlias (vmap[" b" ], vmap[" z" ]));
1533
1533
}
1534
1534
1535
+ TEST (AliasRegistrationTest, ATenSplitIntListAliasCheck) {
1536
+ auto graph = std::make_shared<Graph>();
1537
+ std::unordered_map<std::string, Value*> vmap;
1538
+ auto graph_string = R"IR(
1539
+ graph():
1540
+ %x : Tensor = prim::MakeTestTensor()
1541
+ %0 : int = prim::Constant[value=0]()
1542
+ %1 : int = prim::Constant[value=1]()
1543
+ %2 : int = prim::Constant[value=2]()
1544
+ %y : Tensor = aten::add(%x, %x, %0)
1545
+ %lengths_list : int[] = prim::tolist(%1, %2)
1546
+ %a : Tensor[] = aten::split(%y, %lengths_list, %0)
1547
+ %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
1548
+ %b1 : Tensor = aten::flatten(%b, %0, %1)
1549
+ %c1 : Tensor = aten::flatten(%c, %0, %1)
1550
+ %d : Tensor = aten::add(%b1, %c1, %0)
1551
+ return (%d))IR" ;
1552
+
1553
+ torch::jit::parseIR (graph_string, graph.get (), vmap);
1554
+ AliasDb aliasDb (
1555
+ graph, /* isFrozen=*/ false , /* enablePreciseTupleContainerAnalysis=*/ true );
1556
+
1557
+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" b" ]));
1558
+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" c" ]));
1559
+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" b1" ]));
1560
+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" c1" ]));
1561
+ }
1562
+
1563
+ TEST (AliasRegistrationTest, ATenSplitIntAliasCheck) {
1564
+ auto graph = std::make_shared<Graph>();
1565
+ std::unordered_map<std::string, Value*> vmap;
1566
+ auto graph_string = R"IR(
1567
+ graph():
1568
+ %x : Tensor = prim::MakeTestTensor()
1569
+ %0 : int = prim::Constant[value=0]()
1570
+ %1 : int = prim::Constant[value=1]()
1571
+ %2 : int = prim::Constant[value=2]()
1572
+ %y : Tensor = aten::add(%x, %x, %0)
1573
+ %a : Tensor[] = aten::split(%y, %2, %0)
1574
+ %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
1575
+ %b1 : Tensor = aten::flatten(%b, %0, %1)
1576
+ %c1 : Tensor = aten::flatten(%c, %0, %1)
1577
+ %d : Tensor = aten::add(%b1, %c1, %0)
1578
+ return (%d))IR" ;
1579
+
1580
+ torch::jit::parseIR (graph_string, graph.get (), vmap);
1581
+ AliasDb aliasDb (
1582
+ graph, /* isFrozen=*/ false , /* enablePreciseTupleContainerAnalysis=*/ true );
1583
+
1584
+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" b" ]));
1585
+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" c" ]));
1586
+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" b1" ]));
1587
+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" c1" ]));
1588
+ }
1589
+
1535
1590
TEST (AliasRegistrationTest, PureWithAnnotationsShouldError2) {
1536
1591
auto registry = torch::RegisterOperators ().op (
1537
1592
" foo::rand12(Tensor(a) arg1) -> Tensor(b)" ,
0 commit comments