Skip to content

Arm backend: Remove no-op repeat nodes in ConvertExpandCopyToRepeatPass #10137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
from typing import cast

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

logger = logging.getLogger(__name__)


class ConvertExpandCopyToRepeatPass(ExportPass):
"""
Expand Down Expand Up @@ -41,6 +43,14 @@ def call_operator(self, op, args, kwargs, meta):
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
for i in range(expanded_rank)
]

if all((x == 1 for x in multiples)):
# All dimensions/repetitions occur only once. Remove node
# altogether since it's in practice just a copy.
logger.warning("Found redundant expand node (no-op). Removing it.")

return args[0]

return super().call_operator(
op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
)
33 changes: 25 additions & 8 deletions backends/arm/test/ops/test_expand.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -37,14 +36,14 @@ class Expand(torch.nn.Module):
# (input tensor, multiples)
test_parameters = [
(torch.rand(1), (2,)),
(torch.randn(1, 4), (1, -1)),
(torch.randn(1), (2, 2, 4)),
(torch.randn(1, 1, 1, 5), (1, 4, -1, -1)),
(torch.randn(1, 1, 192), (1, -1, -1)),
(torch.randn(1, 1), (1, 2, 2, 4)),
(torch.randn(1, 1), (2, 2, 2, 4)),
(torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
(torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
(torch.randn(1, 4), (1, -1)),
(torch.randn(1, 1, 192), (1, -1, -1)),
]

def forward(self, x: torch.Tensor, m: Sequence):
Expand Down Expand Up @@ -117,34 +116,52 @@ def test_expand_tosa_MI(self, test_input, multiples):
def test_expand_tosa_BI(self, test_input, multiples):
self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))

@parameterized.expand(Expand.test_parameters[:-3])
@parameterized.expand(Expand.test_parameters[:-5])
@pytest.mark.corstone_fvp
def test_expand_u55_BI(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
)

# MLETORCH-629: Expand does not work on FVP with batch>1
@parameterized.expand(Expand.test_parameters[-3:])
@parameterized.expand(Expand.test_parameters[-5:-2])
@pytest.mark.corstone_fvp
@conftest.expectedFailureOnFVP
def test_expand_u55_BI_xfails_on_fvp(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
)

@parameterized.expand(Expand.test_parameters[-2:])
@pytest.mark.xfail(
reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
)
def test_expand_u55_BI_xfails(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
)

@parameterized.expand(Expand.test_parameters[:-3])
@parameterized.expand(Expand.test_parameters[:-5])
@pytest.mark.corstone_fvp
def test_expand_u85_BI(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
)

# MLETORCH-629: Expand does not work on FVP with batch>1
@parameterized.expand(Expand.test_parameters[-3:])
@parameterized.expand(Expand.test_parameters[-5:-2])
@pytest.mark.corstone_fvp
@conftest.expectedFailureOnFVP
def test_expand_u85_BI_xfails(self, test_input, multiples):
def test_expand_u85_BI_xfails_on_fvp(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
)

@parameterized.expand(Expand.test_parameters[-2:])
@pytest.mark.xfail(
reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
)
def test_expand_u85_xfails(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
)
Loading