Skip to content

Commit 3410cfe

Browse files
authored
Buckify Sigmoid test
Differential Revision: D72940392 Pull Request resolved: #10224
1 parent 0555271 commit 3410cfe

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

backends/arm/test/ops/test_sigmoid.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
from typing import Tuple
1111

12+
import pytest
13+
1214
import torch
13-
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test import common, conftest
1416
from executorch.backends.arm.test.tester.arm_tester import ArmTester
1517
from executorch.exir.backend.compile_spec_schema import CompileSpec
1618
from parameterized import parameterized
@@ -63,7 +65,7 @@ def forward(self, x, y):
6365
def _test_sigmoid_tosa_MI_pipeline(
6466
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
6567
):
66-
(
68+
tester = (
6769
ArmTester(
6870
module,
6971
example_inputs=test_data,
@@ -77,11 +79,13 @@ def _test_sigmoid_tosa_MI_pipeline(
7779
.check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
7880
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
7981
.to_executorch()
80-
.run_method_and_compare_outputs(inputs=test_data)
8182
)
8283

84+
if conftest.is_option_enabled("tosa_ref_model"):
85+
tester.run_method_and_compare_outputs(inputs=test_data)
86+
8387
def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple):
84-
(
88+
tester = (
8589
ArmTester(
8690
module,
8791
example_inputs=test_data,
@@ -96,9 +100,11 @@ def _test_sigmoid_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tup
96100
.check_not(["executorch_exir_dialects_edge__ops_aten_sigmoid_default"])
97101
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
98102
.to_executorch()
99-
.run_method_and_compare_outputs(inputs=test_data)
100103
)
101104

105+
if conftest.is_option_enabled("tosa_ref_model"):
106+
tester.run_method_and_compare_outputs(inputs=test_data)
107+
102108
def _test_sigmoid_tosa_ethos_BI_pipeline(
103109
self,
104110
compile_spec: list[CompileSpec],
@@ -137,6 +143,7 @@ def _test_sigmoid_tosa_u85_BI_pipeline(
137143
)
138144

139145
@parameterized.expand(test_data_suite)
146+
@pytest.mark.tosa_ref_model
140147
def test_sigmoid_tosa_MI(
141148
self,
142149
test_name: str,
@@ -145,26 +152,33 @@ def test_sigmoid_tosa_MI(
145152
self._test_sigmoid_tosa_MI_pipeline(self.Sigmoid(), (test_data,))
146153

147154
@parameterized.expand(test_data_suite)
155+
@pytest.mark.tosa_ref_model
148156
def test_sigmoid_tosa_BI(self, test_name: str, test_data: torch.Tensor):
149157
self._test_sigmoid_tosa_BI_pipeline(self.Sigmoid(), (test_data,))
150158

159+
@pytest.mark.tosa_ref_model
151160
def test_add_sigmoid_tosa_MI(self):
152161
self._test_sigmoid_tosa_MI_pipeline(self.AddSigmoid(), (test_data_suite[0][1],))
153162

163+
@pytest.mark.tosa_ref_model
154164
def test_add_sigmoid_tosa_BI(self):
155165
self._test_sigmoid_tosa_BI_pipeline(self.AddSigmoid(), (test_data_suite[5][1],))
156166

167+
@pytest.mark.tosa_ref_model
157168
def test_sigmoid_add_tosa_MI(self):
158169
self._test_sigmoid_tosa_MI_pipeline(self.SigmoidAdd(), (test_data_suite[0][1],))
159170

171+
@pytest.mark.tosa_ref_model
160172
def test_sigmoid_add_tosa_BI(self):
161173
self._test_sigmoid_tosa_BI_pipeline(self.SigmoidAdd(), (test_data_suite[0][1],))
162174

175+
@pytest.mark.tosa_ref_model
163176
def test_sigmoid_add_sigmoid_tosa_MI(self):
164177
self._test_sigmoid_tosa_MI_pipeline(
165178
self.SigmoidAddSigmoid(), (test_data_suite[4][1], test_data_suite[3][1])
166179
)
167180

181+
@pytest.mark.tosa_ref_model
168182
def test_sigmoid_add_sigmoid_tosa_BI(self):
169183
self._test_sigmoid_tosa_BI_pipeline(
170184
self.SigmoidAddSigmoid(), (test_data_suite[4][1], test_data_suite[3][1])

backends/arm/test/targets.bzl

+5-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@ def define_arm_tests():
1212
test_files.remove("passes/test_ioquantization_pass.py")
1313

1414
# Operators
15-
test_files += ["ops/test_linear.py"]
16-
test_files += ["ops/test_slice.py"]
15+
test_files += [
16+
"ops/test_linear.py",
17+
"ops/test_slice.py",
18+
"ops/test_sigmoid.py",
19+
]
1720

1821
TESTS = {}
1922

0 commit comments

Comments
 (0)