diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py
index aee68f74eb5..5632c253437 100644
--- a/backends/arm/_passes/convert_expand_copy_to_repeat.py
+++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py
@@ -1,16 +1,18 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
-# All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
 
+import logging
 from typing import cast
 
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
 
+logger = logging.getLogger(__name__)
+
 
 class ConvertExpandCopyToRepeatPass(ExportPass):
     """
@@ -41,6 +43,14 @@ def call_operator(self, op, args, kwargs, meta):
             multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
             for i in range(expanded_rank)
         ]
+
+        if all((x == 1 for x in multiples)):
+            # All dimensions/repetitions occur only once. Remove node
+            # altogether since it's in practice just a copy.
+            logger.warning("Found redundant expand node (no-op). Removing it.")
+
+            return args[0]
+
         return super().call_operator(
             op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
         )
diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py
index 9750f660003..b644e729bb4 100644
--- a/backends/arm/test/ops/test_expand.py
+++ b/backends/arm/test/ops/test_expand.py
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -37,14 +36,14 @@ class Expand(torch.nn.Module):
         # (input tensor, multiples)
         test_parameters = [
             (torch.rand(1), (2,)),
-            (torch.randn(1, 4), (1, -1)),
             (torch.randn(1), (2, 2, 4)),
             (torch.randn(1, 1, 1, 5), (1, 4, -1, -1)),
-            (torch.randn(1, 1, 192), (1, -1, -1)),
             (torch.randn(1, 1), (1, 2, 2, 4)),
             (torch.randn(1, 1), (2, 2, 2, 4)),
             (torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
             (torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
+            (torch.randn(1, 4), (1, -1)),
+            (torch.randn(1, 1, 192), (1, -1, -1)),
         ]
 
         def forward(self, x: torch.Tensor, m: Sequence):
@@ -117,7 +116,7 @@ def test_expand_tosa_MI(self, test_input, multiples):
     def test_expand_tosa_BI(self, test_input, multiples):
         self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))
 
-    @parameterized.expand(Expand.test_parameters[:-3])
+    @parameterized.expand(Expand.test_parameters[:-5])
     @pytest.mark.corstone_fvp
     def test_expand_u55_BI(self, test_input, multiples):
         self._test_expand_ethosu_BI_pipeline(
@@ -125,15 +124,24 @@ def test_expand_u55_BI(self, test_input, multiples):
         )
 
     # MLETORCH-629: Expand does not work on FVP with batch>1
-    @parameterized.expand(Expand.test_parameters[-3:])
+    @parameterized.expand(Expand.test_parameters[-5:-2])
     @pytest.mark.corstone_fvp
     @conftest.expectedFailureOnFVP
+    def test_expand_u55_BI_xfails_on_fvp(self, test_input, multiples):
+        self._test_expand_ethosu_BI_pipeline(
+            common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
+        )
+
+    @parameterized.expand(Expand.test_parameters[-2:])
+    @pytest.mark.xfail(
+        reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
+    )
     def test_expand_u55_BI_xfails(self, test_input, multiples):
         self._test_expand_ethosu_BI_pipeline(
             common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
         )
 
-    @parameterized.expand(Expand.test_parameters[:-3])
+    @parameterized.expand(Expand.test_parameters[:-5])
     @pytest.mark.corstone_fvp
     def test_expand_u85_BI(self, test_input, multiples):
         self._test_expand_ethosu_BI_pipeline(
@@ -141,10 +149,19 @@ def test_expand_u85_BI(self, test_input, multiples):
         )
 
     # MLETORCH-629: Expand does not work on FVP with batch>1
-    @parameterized.expand(Expand.test_parameters[-3:])
+    @parameterized.expand(Expand.test_parameters[-5:-2])
     @pytest.mark.corstone_fvp
     @conftest.expectedFailureOnFVP
-    def test_expand_u85_BI_xfails(self, test_input, multiples):
+    def test_expand_u85_BI_xfails_on_fvp(self, test_input, multiples):
+        self._test_expand_ethosu_BI_pipeline(
+            common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
+        )
+
+    @parameterized.expand(Expand.test_parameters[-2:])
+    @pytest.mark.xfail(
+        reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
+    )
+    def test_expand_u85_xfails(self, test_input, multiples):
        self._test_expand_ethosu_BI_pipeline(
             common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
         )
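
For reviewers, a minimal standalone sketch of the no-op check the pass now performs. This is not part of the patch: compute_multiples and the manual left-padding of the input shape are illustrative simplifications of what call_operator does with the node's metadata, but the multiples computation and the "all ones means copy" condition mirror the hunk above.

# Illustrative sketch only, not patch code.
def compute_multiples(input_shape, expand_sizes):
    expanded_rank = len(expand_sizes)
    # Left-pad the input shape with ones, as expand's broadcasting does.
    # (Simplification: the real pass takes the extended shape from node meta.)
    extended_shape = [1] * (expanded_rank - len(input_shape)) + list(input_shape)
    return [
        expand_sizes[i] if expand_sizes[i] != -1 and extended_shape[i] == 1 else 1
        for i in range(expanded_rank)
    ]

# The two test cases moved to the end of test_parameters repeat nothing,
# so the pass now drops the node instead of emitting a repeat (hence the
# new MLETORCH-716 xfails: Vela is left with an empty graph):
assert compute_multiples((1, 4), (1, -1)) == [1, 1]
assert compute_multiples((1, 1, 192), (1, -1, -1)) == [1, 1, 1]
# A genuine expand keeps its repeat multiples:
assert compute_multiples((1, 1), (2, 2, 2, 4)) == [2, 2, 2, 4]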