Skip to content

Commit 04cacc6

Browse files
martinlsmMartin Lindström
and
Martin Lindström
authored
Arm backend: Remove no-op repeat nodes in ConvertExpandCopyToRepeatPass (#10137)
Update ConvertExpandCopyToRepeatPass to eliminate Repeat nodes where all repetition counts are one, as these are functionally equivalent to a simple copy. This reduces unnecessary operations in the graph. This patch affects two unit tests in backends/arm/test/ops/test_expand.py. The models (test subjects) in those unit tests are in the affected cases optimized such that the only computation node (torch.Tensor.expand) they each contain is optimized away by the modified ConvertExpandCopyToRepeatPass. This will make Vela complain saying that the model is empty and the TOSA file cannot be compiled. Set the affected unit tests to xfail since they are expected to be resolved later in another patch. Co-authored-by: Martin Lindström <[email protected]>
1 parent edab231 commit 04cacc6

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

backends/arm/_passes/convert_expand_copy_to_repeat.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
2-
# All rights reserved.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

8+
import logging
99
from typing import cast
1010

1111
from executorch.exir.dialects._ops import ops as exir_ops
1212
from executorch.exir.pass_base import ExportPass
1313

14+
logger = logging.getLogger(__name__)
15+
1416

1517
class ConvertExpandCopyToRepeatPass(ExportPass):
1618
"""
@@ -41,6 +43,14 @@ def call_operator(self, op, args, kwargs, meta):
4143
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
4244
for i in range(expanded_rank)
4345
]
46+
47+
if all((x == 1 for x in multiples)):
48+
# All dimensions/repetitions occur only once. Remove node
49+
# altogether since it's in practice just a copy.
50+
logger.warning("Found redundant expand node (no-op). Removing it.")
51+
52+
return args[0]
53+
4454
return super().call_operator(
4555
op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
4656
)

backends/arm/test/ops/test_expand.py

+25-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
@@ -37,14 +36,14 @@ class Expand(torch.nn.Module):
3736
# (input tensor, multiples)
3837
test_parameters = [
3938
(torch.rand(1), (2,)),
40-
(torch.randn(1, 4), (1, -1)),
4139
(torch.randn(1), (2, 2, 4)),
4240
(torch.randn(1, 1, 1, 5), (1, 4, -1, -1)),
43-
(torch.randn(1, 1, 192), (1, -1, -1)),
4441
(torch.randn(1, 1), (1, 2, 2, 4)),
4542
(torch.randn(1, 1), (2, 2, 2, 4)),
4643
(torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
4744
(torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
45+
(torch.randn(1, 4), (1, -1)),
46+
(torch.randn(1, 1, 192), (1, -1, -1)),
4847
]
4948

5049
def forward(self, x: torch.Tensor, m: Sequence):
@@ -117,34 +116,52 @@ def test_expand_tosa_MI(self, test_input, multiples):
117116
def test_expand_tosa_BI(self, test_input, multiples):
118117
self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))
119118

120-
@parameterized.expand(Expand.test_parameters[:-3])
119+
@parameterized.expand(Expand.test_parameters[:-5])
121120
@pytest.mark.corstone_fvp
122121
def test_expand_u55_BI(self, test_input, multiples):
123122
self._test_expand_ethosu_BI_pipeline(
124123
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
125124
)
126125

127126
# MLETORCH-629: Expand does not work on FVP with batch>1
128-
@parameterized.expand(Expand.test_parameters[-3:])
127+
@parameterized.expand(Expand.test_parameters[-5:-2])
129128
@pytest.mark.corstone_fvp
130129
@conftest.expectedFailureOnFVP
130+
def test_expand_u55_BI_xfails_on_fvp(self, test_input, multiples):
131+
self._test_expand_ethosu_BI_pipeline(
132+
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
133+
)
134+
135+
@parameterized.expand(Expand.test_parameters[-2:])
136+
@pytest.mark.xfail(
137+
reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
138+
)
131139
def test_expand_u55_BI_xfails(self, test_input, multiples):
132140
self._test_expand_ethosu_BI_pipeline(
133141
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
134142
)
135143

136-
@parameterized.expand(Expand.test_parameters[:-3])
144+
@parameterized.expand(Expand.test_parameters[:-5])
137145
@pytest.mark.corstone_fvp
138146
def test_expand_u85_BI(self, test_input, multiples):
139147
self._test_expand_ethosu_BI_pipeline(
140148
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
141149
)
142150

143151
# MLETORCH-629: Expand does not work on FVP with batch>1
144-
@parameterized.expand(Expand.test_parameters[-3:])
152+
@parameterized.expand(Expand.test_parameters[-5:-2])
145153
@pytest.mark.corstone_fvp
146154
@conftest.expectedFailureOnFVP
147-
def test_expand_u85_BI_xfails(self, test_input, multiples):
155+
def test_expand_u85_BI_xfails_on_fvp(self, test_input, multiples):
156+
self._test_expand_ethosu_BI_pipeline(
157+
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
158+
)
159+
160+
@parameterized.expand(Expand.test_parameters[-2:])
161+
@pytest.mark.xfail(
162+
reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
163+
)
164+
def test_expand_u85_xfails(self, test_input, multiples):
148165
self._test_expand_ethosu_BI_pipeline(
149166
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
150167
)

0 commit comments

Comments
 (0)