9
9
10
10
from typing import Tuple
11
11
12
+ import pytest
13
+
12
14
import torch
13
15
14
- from executorch .backends .arm .test import common
16
+ from executorch .backends .arm .test import common , conftest
15
17
from executorch .backends .arm .test .tester .arm_tester import ArmTester
16
18
from executorch .exir .backend .compile_spec_schema import CompileSpec
17
19
from parameterized import parameterized
@@ -40,7 +42,7 @@ def forward(self, x):
40
42
def _test_tanh_tosa_MI_pipeline (
41
43
self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
42
44
):
43
- (
45
+ tester = (
44
46
ArmTester (
45
47
module ,
46
48
example_inputs = test_data ,
@@ -54,11 +56,13 @@ def _test_tanh_tosa_MI_pipeline(
54
56
.check_not (["executorch_exir_dialects_edge__ops_aten_tanh_default" ])
55
57
.check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
56
58
.to_executorch ()
57
- .run_method_and_compare_outputs (inputs = test_data )
58
59
)
59
60
61
+ if conftest .is_option_enabled ("tosa_ref_model" ):
62
+ tester .run_method_and_compare_outputs (inputs = test_data )
63
+
60
64
def _test_tanh_tosa_BI_pipeline (self , module : torch .nn .Module , test_data : Tuple ):
61
- (
65
+ tester = (
62
66
ArmTester (
63
67
module ,
64
68
example_inputs = test_data ,
@@ -73,9 +77,11 @@ def _test_tanh_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: Tuple)
73
77
.check_not (["executorch_exir_dialects_edge__ops_aten_tanh_default" ])
74
78
.check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
75
79
.to_executorch ()
76
- .run_method_and_compare_outputs (inputs = test_data )
77
80
)
78
81
82
+ if conftest .is_option_enabled ("tosa_ref_model" ):
83
+ tester .run_method_and_compare_outputs (inputs = test_data )
84
+
79
85
def _test_tanh_tosa_ethos_BI_pipeline (
80
86
self ,
81
87
compile_spec : list [CompileSpec ],
@@ -114,6 +120,7 @@ def _test_tanh_tosa_u85_BI_pipeline(
114
120
)
115
121
116
122
@parameterized .expand (test_data_suite )
123
+ @pytest .mark .tosa_ref_model
117
124
def test_tanh_tosa_MI (
118
125
self ,
119
126
test_name : str ,
@@ -122,6 +129,7 @@ def test_tanh_tosa_MI(
122
129
self ._test_tanh_tosa_MI_pipeline (self .Tanh (), (test_data ,))
123
130
124
131
@parameterized .expand (test_data_suite )
132
+ @pytest .mark .tosa_ref_model
125
133
def test_tanh_tosa_BI (self , test_name : str , test_data : torch .Tensor ):
126
134
self ._test_tanh_tosa_BI_pipeline (self .Tanh (), (test_data ,))
127
135
0 commit comments