@@ -1532,6 +1532,61 @@ TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
15321532 EXPECT_TRUE (aliasDb.mayContainAlias (vmap[" b" ], vmap[" z" ]));
15331533}
15341534
1535+ TEST (AliasRegistrationTest, ATenSplitIntListAliasCheck) {
1536+ auto graph = std::make_shared<Graph>();
1537+ std::unordered_map<std::string, Value*> vmap;
1538+ auto graph_string = R"IR(
1539+ graph():
1540+ %x : Tensor = prim::MakeTestTensor()
1541+ %0 : int = prim::Constant[value=0]()
1542+ %1 : int = prim::Constant[value=1]()
1543+ %2 : int = prim::Constant[value=2]()
1544+ %y : Tensor = aten::add(%x, %x, %0)
1545+ %lengths_list : int[] = prim::tolist(%1, %2)
1546+ %a : Tensor[] = aten::split(%y, %lengths_list, %0)
1547+ %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
1548+ %b1 : Tensor = aten::flatten(%b, %0, %1)
1549+ %c1 : Tensor = aten::flatten(%c, %0, %1)
1550+ %d : Tensor = aten::add(%b1, %c1, %0)
1551+ return (%d))IR" ;
1552+
1553+ torch::jit::parseIR (graph_string, graph.get (), vmap);
1554+ AliasDb aliasDb (
1555+ graph, /* isFrozen=*/ false , /* enablePreciseTupleContainerAnalysis=*/ true );
1556+
1557+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" b" ]));
1558+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" c" ]));
1559+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" b1" ]));
1560+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" c1" ]));
1561+ }
1562+
1563+ TEST (AliasRegistrationTest, ATenSplitIntAliasCheck) {
1564+ auto graph = std::make_shared<Graph>();
1565+ std::unordered_map<std::string, Value*> vmap;
1566+ auto graph_string = R"IR(
1567+ graph():
1568+ %x : Tensor = prim::MakeTestTensor()
1569+ %0 : int = prim::Constant[value=0]()
1570+ %1 : int = prim::Constant[value=1]()
1571+ %2 : int = prim::Constant[value=2]()
1572+ %y : Tensor = aten::add(%x, %x, %0)
1573+ %a : Tensor[] = aten::split(%y, %2, %0)
1574+ %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
1575+ %b1 : Tensor = aten::flatten(%b, %0, %1)
1576+ %c1 : Tensor = aten::flatten(%c, %0, %1)
1577+ %d : Tensor = aten::add(%b1, %c1, %0)
1578+ return (%d))IR" ;
1579+
1580+ torch::jit::parseIR (graph_string, graph.get (), vmap);
1581+ AliasDb aliasDb (
1582+ graph, /* isFrozen=*/ false , /* enablePreciseTupleContainerAnalysis=*/ true );
1583+
1584+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" b" ]));
1585+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" c" ]));
1586+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" b1" ]));
1587+ EXPECT_TRUE (aliasDb.mayAlias (vmap[" y" ], vmap[" c1" ]));
1588+ }
1589+
15351590TEST (AliasRegistrationTest, PureWithAnnotationsShouldError2) {
15361591 auto registry = torch::RegisterOperators ().op (
15371592 " foo::rand12(Tensor(a) arg1) -> Tensor(b)" ,
0 commit comments