Skip to content

Commit 0867e3a

Browse files
Fix issue in EliminateConcatStridedSlice with compilation problem (#29932)
### Details: Fix errors - 1) the concat & StridedSlice all can be eliminated; - 2) the inputs of concat have dim value = 1 in concat axis - 2) Fix the bug caused by EliminateConcatStridedSlice Pass which did not consider the Nodes which do not have default output index, such as TopK. ### Tickets: - CVS-165105 - CVS-164371 - CVS-165403 --------- Signed-off-by: Zhai, Xuejun <[email protected]> Signed-off-by: Xu, He <[email protected]> Co-authored-by: Xu, He <[email protected]>
1 parent c65243c commit 0867e3a

File tree

2 files changed

+171
-7
lines changed

2 files changed

+171
-7
lines changed

src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp

+15-7
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,8 @@ pass::EliminateConcatStridedSlice::EliminateConcatStridedSlice() {
585585
if (end_constant_node == nullptr)
586586
return false;
587587
auto end_values = end_constant_node->cast_vector<int64_t>();
588+
if (end_values[concat_axis] > static_cast<int64_t>(concat->get_shape()[concat_axis]))
589+
end_values[concat_axis] = static_cast<int64_t>(concat->get_shape()[concat_axis]);
588590

589591
slice_out_index_in_concat.push_back(
590592
std::make_tuple(strided_slice_node, begin_values[concat_axis], end_values[concat_axis] - 1));
@@ -605,13 +607,15 @@ pass::EliminateConcatStridedSlice::EliminateConcatStridedSlice() {
605607
}
606608

607609
node_index_info_map mismatch_slices{};
610+
bool model_changed = false;
608611
for (const auto& [slice_node, slice_begin, slice_end] : slice_out_index_in_concat) {
609612
bool matched = false;
610613
for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) {
611614
if (slice_begin == concat_input_begin && slice_end == concat_input_end) {
612615
auto slice_outputs = slice_node->outputs();
613616
for (auto& slice_output : slice_outputs) {
614617
replace_output_update_name(slice_output, concat_input_node);
618+
model_changed = true;
615619
}
616620
matched = true;
617621
break;
@@ -620,18 +624,23 @@ pass::EliminateConcatStridedSlice::EliminateConcatStridedSlice() {
620624
if (!matched)
621625
mismatch_slices.push_back(std::make_tuple(slice_node, slice_begin, slice_end));
622626
}
627+
if (mismatch_slices.empty())
628+
return model_changed;
629+
630+
if (mismatch_slices.size() == slice_out_index_in_concat.size())
631+
return model_changed;
623632

624633
int64_t new_start_value{std::numeric_limits<int64_t>::max()};
625634
int64_t new_end_value{0};
626635
for (const auto& [slice_node, slice_begin, slice_end] : mismatch_slices) {
627636
for (const auto& [concat_input_node, concat_input_begin, concat_input_end] : in_index_in_concat) {
628-
if ((concat_input_begin <= slice_begin) && (concat_input_end > slice_begin)) {
637+
if ((concat_input_begin <= slice_begin) && (concat_input_end >= slice_begin)) {
629638
if (concat_input_begin < new_start_value)
630639
new_start_value = concat_input_begin;
631640
if (concat_input_end > new_end_value)
632641
new_end_value = concat_input_end;
633642
}
634-
if ((concat_input_begin < slice_end) && (concat_input_end >= slice_end)) {
643+
if ((concat_input_begin <= slice_end) && (concat_input_end >= slice_end)) {
635644
if (concat_input_begin < new_start_value)
636645
new_start_value = concat_input_begin;
637646
if (concat_input_end > new_end_value)
@@ -663,18 +672,17 @@ pass::EliminateConcatStridedSlice::EliminateConcatStridedSlice() {
663672
ov::as_type_ptr<ov::op::v0::Concat>(slice_node->get_users()[0])->get_axis() == concat_axis) {
664673
auto next_concat = ov::as_type_ptr<ov::op::v0::Concat>(slice_node->get_users()[0]);
665674
auto next_concat_inputs = next_concat->input_values();
666-
std::vector<std::shared_ptr<Node>> new_next_concat_inputs{};
675+
ov::OutputVector new_next_concat_inputs{};
667676
for (const auto& t : next_concat_inputs) {
668677
if (t.get_node_shared_ptr() == slice_node) {
669678
for (const auto& need_insert : new_concat_in_nodes) {
670679
new_next_concat_inputs.push_back(need_insert);
671680
}
672-
continue;
681+
} else {
682+
new_next_concat_inputs.push_back(t);
673683
}
674-
new_next_concat_inputs.push_back(t.get_node_shared_ptr());
675684
}
676-
auto new_next_concat_node =
677-
next_concat->clone_with_new_inputs(ov::as_output_vector(new_next_concat_inputs));
685+
auto new_next_concat_node = next_concat->clone_with_new_inputs(new_next_concat_inputs);
678686
replace_output_update_name(next_concat, new_next_concat_node);
679687
} else {
680688
std::vector<std::shared_ptr<Node>> new_slice_in_nodes{};

src/common/transformations/tests/common_optimizations/nop_elimination.cpp

+156
Original file line numberDiff line numberDiff line change
@@ -1598,6 +1598,51 @@ TEST_F(TransformationTestsF, EliminateConcatStridedSlice) {
15981598
}
15991599
}
16001600

1601+
TEST_F(TransformationTestsF, EliminateConcatStridedSliceAll) {
1602+
{
1603+
int64_t axis = 2;
1604+
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 1});
1605+
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 1});
1606+
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param1, param2}), axis);
1607+
1608+
auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
1609+
auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 1});
1610+
auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1611+
begin_const1,
1612+
end_const1,
1613+
std::vector<int64_t>{1, 1, 0},
1614+
std::vector<int64_t>{1, 1, 0});
1615+
1616+
auto relu1 = std::make_shared<op::v0::Relu>(strided_slice1);
1617+
auto result1 = std::make_shared<op::v0::Result>(relu1);
1618+
1619+
auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 1});
1620+
auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 4});
1621+
auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1622+
begin_const2,
1623+
end_const2,
1624+
std::vector<int64_t>{1, 1, 0},
1625+
std::vector<int64_t>{1, 1, 0});
1626+
auto relu2 = std::make_shared<op::v0::Relu>(strided_slice2);
1627+
auto result2 = std::make_shared<op::v0::Result>(relu2);
1628+
1629+
model = std::make_shared<ov::Model>(ResultVector{result1, result2}, ParameterVector{param1, param2});
1630+
manager.register_pass<ov::pass::EliminateConcatStridedSlice>();
1631+
}
1632+
{
1633+
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 1});
1634+
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 1});
1635+
1636+
auto relu1 = std::make_shared<op::v0::Relu>(param1);
1637+
auto result1 = std::make_shared<op::v0::Result>(relu1);
1638+
1639+
auto relu2 = std::make_shared<op::v0::Relu>(param2);
1640+
auto result2 = std::make_shared<op::v0::Result>(relu2);
1641+
1642+
model_ref = std::make_shared<ov::Model>(ResultVector{result1, result2}, ParameterVector{param1, param2});
1643+
}
1644+
}
1645+
16011646
TEST_F(TransformationTestsF, EliminateConcatStridedSliceConcat) {
16021647
{
16031648
int64_t axis = 2;
@@ -1642,6 +1687,117 @@ TEST_F(TransformationTestsF, EliminateConcatStridedSliceConcat) {
16421687
}
16431688
}
16441689

1690+
TEST_F(TransformationTestsF, EliminateConcatStridedSliceConcatMismatch) {
1691+
{
1692+
int64_t axis = 2;
1693+
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 3});
1694+
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 4});
1695+
auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 5});
1696+
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param1, param2, param3}), axis);
1697+
1698+
auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
1699+
auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 4});
1700+
auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1701+
begin_const1,
1702+
end_const1,
1703+
std::vector<int64_t>{1, 1, 0},
1704+
std::vector<int64_t>{1, 1, 0});
1705+
auto relu = std::make_shared<op::v0::Relu>(strided_slice1);
1706+
1707+
auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3});
1708+
auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 10});
1709+
auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1710+
begin_const2,
1711+
end_const2,
1712+
std::vector<int64_t>{1, 1, 0},
1713+
std::vector<int64_t>{1, 1, 0});
1714+
auto concat1 = make_shared<ov::op::v0::Concat>(ov::as_output_vector({relu, strided_slice2}), axis);
1715+
1716+
auto result = std::make_shared<op::v0::Result>(concat1);
1717+
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
1718+
manager.register_pass<ov::pass::EliminateConcatStridedSlice>();
1719+
}
1720+
{
1721+
int64_t axis = 2;
1722+
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 3});
1723+
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 4});
1724+
auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 10, 5});
1725+
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param1, param2, param3}), axis);
1726+
1727+
auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
1728+
auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 4});
1729+
auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1730+
begin_const1,
1731+
end_const1,
1732+
std::vector<int64_t>{1, 1, 0},
1733+
std::vector<int64_t>{1, 1, 0});
1734+
auto relu = std::make_shared<op::v0::Relu>(strided_slice1);
1735+
1736+
auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3});
1737+
auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 10});
1738+
auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1739+
begin_const2,
1740+
end_const2,
1741+
std::vector<int64_t>{1, 1, 0},
1742+
std::vector<int64_t>{1, 1, 0});
1743+
auto concat1 = make_shared<ov::op::v0::Concat>(ov::as_output_vector({relu, strided_slice2}), axis);
1744+
1745+
auto result = std::make_shared<op::v0::Result>(concat1);
1746+
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
1747+
}
1748+
}
1749+
1750+
TEST_F(TransformationTestsF, EliminateConcatStridedSliceTopKConcat) {
1751+
{
1752+
int64_t axis = 2;
1753+
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 10, 3});
1754+
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 10, 4});
1755+
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({param1, param2}), axis);
1756+
1757+
auto begin_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 0});
1758+
auto end_const1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3});
1759+
auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1760+
begin_const1,
1761+
end_const1,
1762+
std::vector<int64_t>{1, 1, 0},
1763+
std::vector<int64_t>{1, 1, 0});
1764+
auto topk = std::make_shared<ov::op::v1::TopK>(strided_slice1,
1765+
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}),
1766+
axis,
1767+
op::v1::TopK::Mode::MAX,
1768+
op::v1::TopK::SortType::NONE);
1769+
1770+
auto begin_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 3});
1771+
auto end_const2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{3}, {0, 0, 7});
1772+
auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1773+
begin_const2,
1774+
end_const2,
1775+
std::vector<int64_t>{1, 1, 0},
1776+
std::vector<int64_t>{1, 1, 0});
1777+
auto topk_values = std::make_shared<ov::op::v0::Result>(topk->output(0));
1778+
auto concat1 = make_shared<ov::op::v0::Concat>(ov::as_output_vector({topk_values, strided_slice2}), axis);
1779+
1780+
auto result = std::make_shared<op::v0::Result>(concat1);
1781+
model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2});
1782+
manager.register_pass<ov::pass::EliminateConcatStridedSlice>();
1783+
}
1784+
{
1785+
int64_t axis = 2;
1786+
auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 10, 3});
1787+
auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 10, 4});
1788+
auto axis_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, axis);
1789+
auto topk = std::make_shared<ov::op::v1::TopK>(param1,
1790+
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}),
1791+
axis,
1792+
op::v1::TopK::Mode::MAX,
1793+
op::v1::TopK::SortType::NONE);
1794+
auto topk_values = std::make_shared<ov::op::v0::Result>(topk->output(0));
1795+
auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector({topk_values, param2}), axis);
1796+
auto result = std::make_shared<op::v0::Result>(concat);
1797+
model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2});
1798+
}
1799+
}
1800+
16451801
TEST_F(TransformationTestsF, EliminateConcatStridedSliceConcatDiffAxis) {
16461802
{
16471803
int64_t axis = 2;

0 commit comments

Comments
 (0)