Skip to content

Commit 09c0ce9

Browse files
Arm backend: Add support alias_copy operator (#10199)
- Add NodeVisitor factory for IDENTITY ops - Add alias_copy support - Add alias_copy tests - Move getitem to new factory Signed-off-by: Iliyan Georgiev <[email protected]>
1 parent 64fdebe commit 09c0ce9

File tree

6 files changed

+138
-37
lines changed

6 files changed

+138
-37
lines changed

backends/arm/operator_support/tosa_supported_operators.py

+1
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def is_node_supported(
229229
exir_ops.edge.aten.__lshift__.Scalar,
230230
torch.ops.aten.scalar_tensor.default,
231231
exir_ops.edge.aten.gelu.default,
232+
exir_ops.edge.aten.alias_copy.default,
232233
]
233234

234235
return supported

backends/arm/operators/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
op_erf,
2323
op_exp,
2424
op_ge,
25-
op_get_item,
2625
op_gt,
2726
op_le,
2827
op_log,
@@ -51,5 +50,6 @@
5150
op_view,
5251
op_where,
5352
ops_binary,
53+
ops_identity,
5454
ops_unary,
5555
)

backends/arm/operators/op_get_item.py

-35
This file was deleted.
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import torch
11+
import torch.fx
12+
13+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
14+
15+
from executorch.backends.arm.operators.node_visitor import (
16+
NodeVisitor,
17+
register_node_visitor,
18+
)
19+
from executorch.backends.arm.tosa_mapping import TosaArg
20+
21+
22+
def identity_operator_factory(identity_target: str):
23+
"""
24+
Creates and registers NodeVisitors for operators that map directly
25+
to a TOSA IDENTITY op.
26+
"""
27+
28+
class IdentityOperatorVisitor(NodeVisitor):
29+
target = identity_target
30+
31+
def define_node(
32+
self,
33+
node: torch.fx.Node,
34+
tosa_graph: ts.TosaSerializer,
35+
inputs: List[TosaArg],
36+
output: TosaArg,
37+
) -> None:
38+
# Simply add an identityOp
39+
tosa_graph.addOperator(
40+
ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]
41+
)
42+
43+
register_node_visitor(IdentityOperatorVisitor)
44+
45+
46+
identity_operator_factory("getitem")
47+
identity_operator_factory("aten.alias_copy.default")

backends/arm/quantizer/quantization_annotator.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,11 @@ def _match_pattern(
244244
operator.getitem,
245245
]
246246

247+
_one_to_one_shared_input_or_input_act_qspec = [
248+
torch.ops.aten.adaptive_avg_pool2d.default,
249+
torch.ops.aten.alias_copy.default,
250+
]
251+
247252

248253
def get_quant_properties( # noqa: C901
249254
node: Node, gm: torch.fx.GraphModule, quantization_config
@@ -332,7 +337,7 @@ def any_or_hardtanh_min_zero(n: Node):
332337
_QuantProperty(2, shared_qspec), # type: ignore[arg-type]
333338
]
334339
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
335-
elif node.target == torch.ops.aten.adaptive_avg_pool2d.default:
340+
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
336341
input_qspec = (
337342
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
338343
if arm_quantizer_utils.is_output_annotated(node.args[0]) # type: ignore
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
EthosU55PipelineBI,
12+
EthosU85PipelineBI,
13+
TosaPipelineBI,
14+
TosaPipelineMI,
15+
)
16+
17+
input_t1 = Tuple[torch.Tensor]
18+
19+
20+
class AliasCopy(torch.nn.Module):
21+
"""
22+
Tests proper handling of alias_copy when used directly.
23+
24+
alias_copy can also appear from PyTorch/ExecuTorch optimizations
25+
such as `x.transpose(0, 0)`. This is optimized to an alias_copy but
26+
not before dq/q operators are added.
27+
"""
28+
29+
aten_op = "torch.ops.aten.alias_copy.default"
30+
exir_op = "executorch_exir_dialects_edge__ops_aten_alias_copy_default"
31+
32+
test_data: dict[input_t1] = {
33+
"1d_ramp": (torch.arange(-16, 16, 0.2),),
34+
"2d_ones": (torch.ones(5, 5),),
35+
"3d_rand": (torch.rand(3, 5, 5),),
36+
"4d_zeros": (torch.zeros(1, 10, 10, 10),),
37+
}
38+
39+
def __init__(self):
40+
super().__init__()
41+
42+
def forward(self, x: torch.Tensor):
43+
return torch.alias_copy(x)
44+
45+
46+
@common.parametrize("test_data", AliasCopy.test_data)
47+
def test_alias_copy_tosa_MI(test_data: input_t1):
48+
TosaPipelineMI[input_t1](
49+
AliasCopy(),
50+
test_data,
51+
AliasCopy.aten_op,
52+
AliasCopy.exir_op,
53+
).run()
54+
55+
56+
@common.parametrize("test_data", AliasCopy.test_data)
57+
def test_alias_copy_tosa_BI(test_data: input_t1):
58+
TosaPipelineBI[input_t1](
59+
AliasCopy(),
60+
test_data,
61+
AliasCopy.aten_op,
62+
AliasCopy.exir_op,
63+
).run()
64+
65+
66+
@common.parametrize("test_data", AliasCopy.test_data)
67+
def test_alias_copy_u55_BI(test_data: input_t1):
68+
EthosU55PipelineBI[input_t1](
69+
AliasCopy(),
70+
test_data,
71+
AliasCopy.aten_op,
72+
AliasCopy.exir_op,
73+
).run()
74+
75+
76+
@common.parametrize("test_data", AliasCopy.test_data)
77+
def test_alias_copy_u85_BI(test_data: input_t1):
78+
EthosU85PipelineBI[input_t1](
79+
AliasCopy(),
80+
test_data,
81+
AliasCopy.aten_op,
82+
AliasCopy.exir_op,
83+
).run()

0 commit comments

Comments
 (0)