Skip to content

Commit 7fc176d

Browse files
authored
Arm backend: Add support to eq.Scalar (#9715)
Implement eq.Scalar by converting to eq.Tensor using replace_scalar_with_tensor_pass and match_arg_ranks_pass. * Convert eq.Tensor to eq.Scalar * Expand test_eq to test both eq.Tensor and eq.Scalar Signed-off-by: Fang-Ching <[email protected]>
1 parent f174d55 commit 7fc176d

File tree

6 files changed

+66
-48
lines changed

6 files changed

+66
-48
lines changed

backends/arm/_passes/match_arg_ranks_pass.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(self, exported_program):
4747
exir_ops.edge.aten.div.Tensor,
4848
exir_ops.edge.aten.bitwise_right_shift.Tensor,
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
50+
exir_ops.edge.aten.eq.Tensor,
5051
]
5152

5253
def _match_op_rank(self, graph_module, node, arg, max_rank):

backends/arm/operator_support/tosa_supported_operators.py

+2
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def is_node_supported(
158158
exir_ops.edge.aten.hardswish.default,
159159
exir_ops.edge.aten.div.Tensor,
160160
exir_ops.edge.aten.eq.Tensor,
161+
exir_ops.edge.aten.eq.Scalar,
161162
exir_ops.edge.aten.exp.default,
162163
exir_ops.edge.aten.log.default,
163164
exir_ops.edge.aten.linear.default,
@@ -235,6 +236,7 @@ class EthosU55NotSupported(OperatorSupportBase):
235236
exir_ops.edge.aten.amax.default, # REDUCE_MAX
236237
exir_ops.edge.aten.amin.default, # REDUCE_MIN
237238
exir_ops.edge.aten.eq.Tensor,
239+
exir_ops.edge.aten.eq.Scalar,
238240
exir_ops.edge.aten.ge.Tensor,
239241
exir_ops.edge.aten.gt.Tensor,
240242
exir_ops.edge.aten.le.Tensor,

backends/arm/test/models/test_conformer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,10 @@ class TestConformer(unittest.TestCase):
3131
# .to_executorch step, i.e. after Arm partitioner.
3232
ops_after_partitioner = {
3333
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
34-
"executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
3534
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
3635
"torch.ops.aten._assert_scalar.default": 10,
3736
"torch.ops.aten._local_scalar_dense.default": 1,
38-
"torch.ops.higher_order.executorch_call_delegate": 6,
37+
"torch.ops.higher_order.executorch_call_delegate": 4,
3938
}
4039

4140
dim = 16

backends/arm/test/models/test_llama.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def test_llama_tosa_MI(self):
114114
)
115115
.export()
116116
.to_edge_transform_and_lower()
117-
.check_count({"torch.ops.higher_order.executorch_call_delegate": 26})
117+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
118118
.to_executorch()
119119
.run_method_and_compare_outputs(
120120
inputs=llama_inputs,

backends/arm/test/ops/test_eq.py

+59-45
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from typing import Tuple
77

8-
import pytest
98
import torch
109
from executorch.backends.arm.test import common
1110

@@ -16,13 +15,15 @@
1615
TosaPipelineMI,
1716
)
1817

19-
aten_op = "torch.ops.aten.eq.Tensor"
20-
exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
2118

2219
input_t = Tuple[torch.Tensor]
2320

2421

2522
class Equal(torch.nn.Module):
23+
aten_op_BI = "torch.ops.aten.eq.Tensor"
24+
aten_op_MI = "torch.ops.aten.eq.Scalar"
25+
exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
26+
2627
def __init__(self, input, other):
2728
super().__init__()
2829
self.input_ = input
@@ -31,106 +32,119 @@ def __init__(self, input, other):
3132
def forward(
3233
self,
3334
input_: torch.Tensor,
34-
other_: torch.Tensor,
35+
other_: torch.Tensor | int | float,
3536
):
3637
return input_ == other_
3738

3839
def get_inputs(self):
3940
return (self.input_, self.other_)
4041

4142

42-
op_eq_rank1_ones = Equal(
43+
op_eq_tensor_rank1_ones = Equal(
4344
torch.ones(5),
4445
torch.ones(5),
4546
)
46-
op_eq_rank2_rand = Equal(
47+
op_eq_tensor_rank2_rand = Equal(
4748
torch.rand(4, 5),
4849
torch.rand(1, 5),
4950
)
50-
op_eq_rank3_randn = Equal(
51+
op_eq_tensor_rank3_randn = Equal(
5152
torch.randn(10, 5, 2),
5253
torch.randn(10, 5, 2),
5354
)
54-
op_eq_rank4_randn = Equal(
55+
op_eq_tensor_rank4_randn = Equal(
5556
torch.randn(3, 2, 2, 2),
5657
torch.randn(3, 2, 2, 2),
5758
)
5859

59-
test_data_common = {
60-
"eq_rank1_ones": op_eq_rank1_ones,
61-
"eq_rank2_rand": op_eq_rank2_rand,
62-
"eq_rank3_randn": op_eq_rank3_randn,
63-
"eq_rank4_randn": op_eq_rank4_randn,
60+
op_eq_scalar_rank1_ones = Equal(torch.ones(5), 1.0)
61+
op_eq_scalar_rank2_rand = Equal(torch.rand(4, 5), 0.2)
62+
op_eq_scalar_rank3_randn = Equal(torch.randn(10, 5, 2), -0.1)
63+
op_eq_scalar_rank4_randn = Equal(torch.randn(3, 2, 2, 2), 0.3)
64+
65+
test_data_tensor = {
66+
"eq_tensor_rank1_ones": op_eq_tensor_rank1_ones,
67+
"eq_tensor_rank2_rand": op_eq_tensor_rank2_rand,
68+
"eq_tensor_rank3_randn": op_eq_tensor_rank3_randn,
69+
"eq_tensor_rank4_randn": op_eq_tensor_rank4_randn,
6470
}
6571

72+
test_data_scalar = {
73+
"eq_scalar_rank1_ones": op_eq_scalar_rank1_ones,
74+
"eq_scalar_rank2_rand": op_eq_scalar_rank2_rand,
75+
"eq_scalar_rank3_randn": op_eq_scalar_rank3_randn,
76+
"eq_scalar_rank4_randn": op_eq_scalar_rank4_randn,
77+
}
78+
79+
80+
@common.parametrize("test_module", test_data_tensor)
81+
def test_eq_tensor_tosa_MI(test_module):
82+
pipeline = TosaPipelineMI[input_t](
83+
test_module, test_module.get_inputs(), Equal.aten_op_BI, Equal.exir_op
84+
)
85+
pipeline.run()
6686

67-
@common.parametrize("test_module", test_data_common)
68-
def test_eq_tosa_MI(test_module):
87+
88+
@common.parametrize("test_module", test_data_scalar)
89+
def test_eq_scalar_tosa_MI(test_module):
6990
pipeline = TosaPipelineMI[input_t](
70-
test_module, test_module.get_inputs(), aten_op, exir_op
91+
test_module,
92+
test_module.get_inputs(),
93+
Equal.aten_op_MI,
94+
Equal.exir_op,
7195
)
7296
pipeline.run()
7397

7498

75-
@common.parametrize("test_module", test_data_common)
99+
@common.parametrize("test_module", test_data_tensor | test_data_scalar)
76100
def test_eq_tosa_BI(test_module):
77101
pipeline = TosaPipelineBI[input_t](
78-
test_module, test_module.get_inputs(), aten_op, exir_op
102+
test_module, test_module.get_inputs(), Equal.aten_op_BI, Equal.exir_op
79103
)
80104
pipeline.run()
81105

82106

83-
@common.parametrize("test_module", test_data_common)
84-
def test_eq_u55_BI(test_module):
107+
@common.parametrize("test_module", test_data_tensor)
108+
@common.XfailIfNoCorstone300
109+
def test_eq_tensor_u55_BI(test_module):
85110
# EQUAL is not supported on U55.
86111
pipeline = OpNotSupportedPipeline[input_t](
87112
test_module,
88113
test_module.get_inputs(),
89114
"TOSA-0.80+BI+u55",
90-
{exir_op: 1},
91-
)
92-
pipeline.run()
93-
94-
95-
@common.parametrize("test_module", test_data_common)
96-
def test_eq_u85_BI(test_module):
97-
pipeline = EthosU85PipelineBI[input_t](
98-
test_module,
99-
test_module.get_inputs(),
100-
aten_op,
101-
exir_op,
102-
run_on_fvp=False,
103-
use_to_edge_transform_and_lower=True,
115+
{Equal.exir_op: 1},
104116
)
105117
pipeline.run()
106118

107119

108-
@common.parametrize("test_module", test_data_common)
109-
@pytest.mark.skip(reason="The same as test_eq_u55_BI")
110-
def test_eq_u55_BI_on_fvp(test_module):
120+
@common.parametrize("test_module", test_data_scalar)
121+
@common.XfailIfNoCorstone300
122+
def test_eq_scalar_u55_BI(test_module):
111123
# EQUAL is not supported on U55.
112124
pipeline = OpNotSupportedPipeline[input_t](
113125
test_module,
114126
test_module.get_inputs(),
115127
"TOSA-0.80+BI+u55",
116-
{exir_op: 1},
128+
{Equal.exir_op: 1},
129+
n_expected_delegates=1,
117130
)
118131
pipeline.run()
119132

120133

121134
@common.parametrize(
122135
"test_module",
123-
test_data_common,
124-
xfails={"eq_rank4_randn": "4D fails because boolean Tensors can't be subtracted"},
136+
test_data_tensor | test_data_scalar,
137+
xfails={
138+
"eq_tensor_rank4_randn": "4D fails because boolean Tensors can't be subtracted",
139+
},
125140
)
126-
@common.SkipIfNoCorstone320
127-
def test_eq_u85_BI_on_fvp(test_module):
141+
@common.XfailIfNoCorstone320
142+
def test_eq_u85_BI(test_module):
128143
pipeline = EthosU85PipelineBI[input_t](
129144
test_module,
130145
test_module.get_inputs(),
131-
aten_op,
132-
exir_op,
146+
Equal.aten_op_BI,
147+
Equal.exir_op,
133148
run_on_fvp=True,
134-
use_to_edge_transform_and_lower=True,
135149
)
136150
pipeline.run()

backends/transforms/replace_scalar_with_tensor.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ class ReplaceScalarWithTensorArgPass(ExportPass):
2626
exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
2727
exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor,
2828
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
29+
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
2930
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3031
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
3132
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
3233
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
3334
torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor,
3435
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
36+
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
3537
}
3638

3739
def get_replacement(self, op, args, kwargs, meta):

0 commit comments

Comments
 (0)