@@ -1598,6 +1598,51 @@ TEST_F(TransformationTestsF, EliminateConcatStridedSlice) {
1598
1598
}
1599
1599
}
1600
1600
1601
+ TEST_F (TransformationTestsF, EliminateConcatStridedSliceAll) {
1602
+ {
1603
+ int64_t axis = 2 ;
1604
+ auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 1 });
1605
+ auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 1 });
1606
+ auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector ({param1, param2}), axis);
1607
+
1608
+ auto begin_const1 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 0 });
1609
+ auto end_const1 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 1 });
1610
+ auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1611
+ begin_const1,
1612
+ end_const1,
1613
+ std::vector<int64_t >{1 , 1 , 0 },
1614
+ std::vector<int64_t >{1 , 1 , 0 });
1615
+
1616
+ auto relu1 = std::make_shared<op::v0::Relu>(strided_slice1);
1617
+ auto result1 = std::make_shared<op::v0::Result>(relu1);
1618
+
1619
+ auto begin_const2 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 1 });
1620
+ auto end_const2 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 4 });
1621
+ auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1622
+ begin_const2,
1623
+ end_const2,
1624
+ std::vector<int64_t >{1 , 1 , 0 },
1625
+ std::vector<int64_t >{1 , 1 , 0 });
1626
+ auto relu2 = std::make_shared<op::v0::Relu>(strided_slice2);
1627
+ auto result2 = std::make_shared<op::v0::Result>(relu2);
1628
+
1629
+ model = std::make_shared<ov::Model>(ResultVector{result1, result2}, ParameterVector{param1, param2});
1630
+ manager.register_pass <ov::pass::EliminateConcatStridedSlice>();
1631
+ }
1632
+ {
1633
+ auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 1 });
1634
+ auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 1 });
1635
+
1636
+ auto relu1 = std::make_shared<op::v0::Relu>(param1);
1637
+ auto result1 = std::make_shared<op::v0::Result>(relu1);
1638
+
1639
+ auto relu2 = std::make_shared<op::v0::Relu>(param2);
1640
+ auto result2 = std::make_shared<op::v0::Result>(relu2);
1641
+
1642
+ model_ref = std::make_shared<ov::Model>(ResultVector{result1, result2}, ParameterVector{param1, param2});
1643
+ }
1644
+ }
1645
+
1601
1646
TEST_F (TransformationTestsF, EliminateConcatStridedSliceConcat) {
1602
1647
{
1603
1648
int64_t axis = 2 ;
@@ -1642,6 +1687,117 @@ TEST_F(TransformationTestsF, EliminateConcatStridedSliceConcat) {
1642
1687
}
1643
1688
}
1644
1689
1690
+ TEST_F (TransformationTestsF, EliminateConcatStridedSliceConcatMismatch) {
1691
+ {
1692
+ int64_t axis = 2 ;
1693
+ auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 3 });
1694
+ auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 4 });
1695
+ auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 5 });
1696
+ auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector ({param1, param2, param3}), axis);
1697
+
1698
+ auto begin_const1 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 0 });
1699
+ auto end_const1 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 4 });
1700
+ auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1701
+ begin_const1,
1702
+ end_const1,
1703
+ std::vector<int64_t >{1 , 1 , 0 },
1704
+ std::vector<int64_t >{1 , 1 , 0 });
1705
+ auto relu = std::make_shared<op::v0::Relu>(strided_slice1);
1706
+
1707
+ auto begin_const2 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 3 });
1708
+ auto end_const2 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 10 });
1709
+ auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1710
+ begin_const2,
1711
+ end_const2,
1712
+ std::vector<int64_t >{1 , 1 , 0 },
1713
+ std::vector<int64_t >{1 , 1 , 0 });
1714
+ auto concat1 = make_shared<ov::op::v0::Concat>(ov::as_output_vector ({relu, strided_slice2}), axis);
1715
+
1716
+ auto result = std::make_shared<op::v0::Result>(concat1);
1717
+ model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
1718
+ manager.register_pass <ov::pass::EliminateConcatStridedSlice>();
1719
+ }
1720
+ {
1721
+ int64_t axis = 2 ;
1722
+ auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 3 });
1723
+ auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 4 });
1724
+ auto param3 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2 , 10 , 5 });
1725
+ auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector ({param1, param2, param3}), axis);
1726
+
1727
+ auto begin_const1 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 0 });
1728
+ auto end_const1 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 4 });
1729
+ auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1730
+ begin_const1,
1731
+ end_const1,
1732
+ std::vector<int64_t >{1 , 1 , 0 },
1733
+ std::vector<int64_t >{1 , 1 , 0 });
1734
+ auto relu = std::make_shared<op::v0::Relu>(strided_slice1);
1735
+
1736
+ auto begin_const2 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 3 });
1737
+ auto end_const2 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 10 });
1738
+ auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1739
+ begin_const2,
1740
+ end_const2,
1741
+ std::vector<int64_t >{1 , 1 , 0 },
1742
+ std::vector<int64_t >{1 , 1 , 0 });
1743
+ auto concat1 = make_shared<ov::op::v0::Concat>(ov::as_output_vector ({relu, strided_slice2}), axis);
1744
+
1745
+ auto result = std::make_shared<op::v0::Result>(concat1);
1746
+ model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2, param3});
1747
+ }
1748
+ }
1749
+
1750
+ TEST_F (TransformationTestsF, EliminateConcatStridedSliceTopKConcat) {
1751
+ {
1752
+ int64_t axis = 2 ;
1753
+ auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1 , 10 , 3 });
1754
+ auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1 , 10 , 4 });
1755
+ auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector ({param1, param2}), axis);
1756
+
1757
+ auto begin_const1 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 0 });
1758
+ auto end_const1 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 3 });
1759
+ auto strided_slice1 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1760
+ begin_const1,
1761
+ end_const1,
1762
+ std::vector<int64_t >{1 , 1 , 0 },
1763
+ std::vector<int64_t >{1 , 1 , 0 });
1764
+ auto topk = std::make_shared<ov::op::v1::TopK>(strided_slice1,
1765
+ ov::op::v0::Constant::create (ov::element::i64, ov::Shape{}, {1 }),
1766
+ axis,
1767
+ op::v1::TopK::Mode::MAX,
1768
+ op::v1::TopK::SortType::NONE);
1769
+
1770
+ auto begin_const2 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 3 });
1771
+ auto end_const2 = ov::op::v0::Constant::create (ov::element::i64, ov::Shape{3 }, {0 , 0 , 7 });
1772
+ auto strided_slice2 = std::make_shared<ov::op::v1::StridedSlice>(concat,
1773
+ begin_const2,
1774
+ end_const2,
1775
+ std::vector<int64_t >{1 , 1 , 0 },
1776
+ std::vector<int64_t >{1 , 1 , 0 });
1777
+ auto topk_values = std::make_shared<ov::op::v0::Result>(topk->output (0 ));
1778
+ auto concat1 = make_shared<ov::op::v0::Concat>(ov::as_output_vector ({topk_values, strided_slice2}), axis);
1779
+
1780
+ auto result = std::make_shared<op::v0::Result>(concat1);
1781
+ model = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2});
1782
+ manager.register_pass <ov::pass::EliminateConcatStridedSlice>();
1783
+ }
1784
+ {
1785
+ int64_t axis = 2 ;
1786
+ auto param1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1 , 10 , 3 });
1787
+ auto param2 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1 , 10 , 4 });
1788
+ auto axis_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, axis);
1789
+ auto topk = std::make_shared<ov::op::v1::TopK>(param1,
1790
+ ov::op::v0::Constant::create (ov::element::i64, ov::Shape{}, {1 }),
1791
+ axis,
1792
+ op::v1::TopK::Mode::MAX,
1793
+ op::v1::TopK::SortType::NONE);
1794
+ auto topk_values = std::make_shared<ov::op::v0::Result>(topk->output (0 ));
1795
+ auto concat = make_shared<ov::op::v0::Concat>(ov::as_output_vector ({topk_values, param2}), axis);
1796
+ auto result = std::make_shared<op::v0::Result>(concat);
1797
+ model_ref = std::make_shared<ov::Model>(ResultVector{result}, ParameterVector{param1, param2});
1798
+ }
1799
+ }
1800
+
1645
1801
TEST_F (TransformationTestsF, EliminateConcatStridedSliceConcatDiffAxis) {
1646
1802
{
1647
1803
int64_t axis = 2 ;
0 commit comments