From 3e6dea976867a077d63aeaefa534b3b7fe2f6882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Wed, 9 Apr 2025 16:29:54 +0200 Subject: [PATCH] Remove no-op repeat nodes in ConvertExpandCopyToRepeatPass Update ConvertExpandCopyToRepeatPass to eliminate Repeat nodes where all repetition counts are one, as these are functionally equivalent to a simple copy. This reduces unnecessary operations in the graph. This patch affects two unit tests in backends/arm/test/ops/test_expand.py. The models (test subjects) in those unit tests are in the affected cases optimized such that the only computation node (torch.Tensor.expand) they each contain is optimized away by the modified ConvertExpandCopyToRepeatPass. This will make Vela complain saying that the model is empty and the TOSA file cannot be compiled. Set the affected unit tests to xfail since they are expected to be resolved later in another patch. Change-Id: I8f118c30eb9ecde09a7b8dcefd41b1fb65af8a3b --- .../_passes/convert_expand_copy_to_repeat.py | 14 ++++++-- backends/arm/test/ops/test_expand.py | 33 ++++++++++++++----- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/backends/arm/_passes/convert_expand_copy_to_repeat.py b/backends/arm/_passes/convert_expand_copy_to_repeat.py index aee68f74eb5..5632c253437 100644 --- a/backends/arm/_passes/convert_expand_copy_to_repeat.py +++ b/backends/arm/_passes/convert_expand_copy_to_repeat.py @@ -1,16 +1,18 @@ -# Copyright 2024 Arm Limited and/or its affiliates. -# All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe +import logging from typing import cast from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass +logger = logging.getLogger(__name__) + class ConvertExpandCopyToRepeatPass(ExportPass): """ @@ -41,6 +43,14 @@ def call_operator(self, op, args, kwargs, meta): multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1 for i in range(expanded_rank) ] + + if all((x == 1 for x in multiples)): + # All dimensions/repetitions occur only once. Remove node + # altogether since it's in practice just a copy. + logger.warning("Found redundant expand node (no-op). Removing it.") + + return args[0] + return super().call_operator( op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta ) diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index 9750f660003..b644e729bb4 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -1,5 +1,4 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -37,14 +36,14 @@ class Expand(torch.nn.Module): # (input tensor, multiples) test_parameters = [ (torch.rand(1), (2,)), - (torch.randn(1, 4), (1, -1)), (torch.randn(1), (2, 2, 4)), (torch.randn(1, 1, 1, 5), (1, 4, -1, -1)), - (torch.randn(1, 1, 192), (1, -1, -1)), (torch.randn(1, 1), (1, 2, 2, 4)), (torch.randn(1, 1), (2, 2, 2, 4)), (torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)), (torch.rand(1, 1, 2, 2), (4, 3, -1, 2)), + (torch.randn(1, 4), (1, -1)), + (torch.randn(1, 1, 192), (1, -1, -1)), ] def forward(self, x: torch.Tensor, m: Sequence): @@ -117,7 +116,7 @@ def test_expand_tosa_MI(self, test_input, multiples): def test_expand_tosa_BI(self, test_input, multiples): self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples)) - @parameterized.expand(Expand.test_parameters[:-3]) + @parameterized.expand(Expand.test_parameters[:-5]) @pytest.mark.corstone_fvp def test_expand_u55_BI(self, test_input, multiples): self._test_expand_ethosu_BI_pipeline( @@ -125,15 +124,24 @@ def test_expand_u55_BI(self, test_input, multiples): ) # MLETORCH-629: Expand does not work on FVP with batch>1 - @parameterized.expand(Expand.test_parameters[-3:]) + @parameterized.expand(Expand.test_parameters[-5:-2]) @pytest.mark.corstone_fvp @conftest.expectedFailureOnFVP + def test_expand_u55_BI_xfails_on_fvp(self, test_input, multiples): + self._test_expand_ethosu_BI_pipeline( + common.get_u55_compile_spec(), self.Expand(), (test_input, multiples) + ) + + @parameterized.expand(Expand.test_parameters[-2:]) + @pytest.mark.xfail( + reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs" + ) def test_expand_u55_BI_xfails(self, test_input, multiples): self._test_expand_ethosu_BI_pipeline( common.get_u55_compile_spec(), self.Expand(), (test_input, multiples) ) - @parameterized.expand(Expand.test_parameters[:-3]) + @parameterized.expand(Expand.test_parameters[:-5]) @pytest.mark.corstone_fvp def test_expand_u85_BI(self, test_input, multiples): self._test_expand_ethosu_BI_pipeline( @@ -141,10 +149,19 @@ def test_expand_u85_BI(self, test_input, multiples): ) # MLETORCH-629: Expand does not work on FVP with batch>1 - @parameterized.expand(Expand.test_parameters[-3:]) + @parameterized.expand(Expand.test_parameters[-5:-2]) @pytest.mark.corstone_fvp @conftest.expectedFailureOnFVP - def test_expand_u85_BI_xfails(self, test_input, multiples): + def test_expand_u85_BI_xfails_on_fvp(self, test_input, multiples): + self._test_expand_ethosu_BI_pipeline( + common.get_u85_compile_spec(), self.Expand(), (test_input, multiples) + ) + + @parameterized.expand(Expand.test_parameters[-2:]) + @pytest.mark.xfail( + reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs" + ) + def test_expand_u85_xfails(self, test_input, multiples): self._test_expand_ethosu_BI_pipeline( common.get_u85_compile_spec(), self.Expand(), (test_input, multiples) )