5
5
6
6
from typing import Tuple
7
7
8
- import pytest
9
8
import torch
10
9
from executorch .backends .arm .test import common
11
10
16
15
TosaPipelineMI ,
17
16
)
18
17
19
- aten_op = "torch.ops.aten.eq.Tensor"
20
- exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
21
18
22
19
input_t = Tuple [torch .Tensor ]
23
20
24
21
25
22
class Equal (torch .nn .Module ):
23
+ aten_op_BI = "torch.ops.aten.eq.Tensor"
24
+ aten_op_MI = "torch.ops.aten.eq.Scalar"
25
+ exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor"
26
+
26
27
def __init__ (self , input , other ):
27
28
super ().__init__ ()
28
29
self .input_ = input
@@ -31,106 +32,119 @@ def __init__(self, input, other):
31
32
def forward (
32
33
self ,
33
34
input_ : torch .Tensor ,
34
- other_ : torch .Tensor ,
35
+ other_ : torch .Tensor | int | float ,
35
36
):
36
37
return input_ == other_
37
38
38
39
def get_inputs (self ):
39
40
return (self .input_ , self .other_ )
40
41
41
42
42
- op_eq_rank1_ones = Equal (
43
+ op_eq_tensor_rank1_ones = Equal (
43
44
torch .ones (5 ),
44
45
torch .ones (5 ),
45
46
)
46
- op_eq_rank2_rand = Equal (
47
+ op_eq_tensor_rank2_rand = Equal (
47
48
torch .rand (4 , 5 ),
48
49
torch .rand (1 , 5 ),
49
50
)
50
- op_eq_rank3_randn = Equal (
51
+ op_eq_tensor_rank3_randn = Equal (
51
52
torch .randn (10 , 5 , 2 ),
52
53
torch .randn (10 , 5 , 2 ),
53
54
)
54
- op_eq_rank4_randn = Equal (
55
+ op_eq_tensor_rank4_randn = Equal (
55
56
torch .randn (3 , 2 , 2 , 2 ),
56
57
torch .randn (3 , 2 , 2 , 2 ),
57
58
)
58
59
59
- test_data_common = {
60
- "eq_rank1_ones" : op_eq_rank1_ones ,
61
- "eq_rank2_rand" : op_eq_rank2_rand ,
62
- "eq_rank3_randn" : op_eq_rank3_randn ,
63
- "eq_rank4_randn" : op_eq_rank4_randn ,
60
+ op_eq_scalar_rank1_ones = Equal (torch .ones (5 ), 1.0 )
61
+ op_eq_scalar_rank2_rand = Equal (torch .rand (4 , 5 ), 0.2 )
62
+ op_eq_scalar_rank3_randn = Equal (torch .randn (10 , 5 , 2 ), - 0.1 )
63
+ op_eq_scalar_rank4_randn = Equal (torch .randn (3 , 2 , 2 , 2 ), 0.3 )
64
+
65
+ test_data_tensor = {
66
+ "eq_tensor_rank1_ones" : op_eq_tensor_rank1_ones ,
67
+ "eq_tensor_rank2_rand" : op_eq_tensor_rank2_rand ,
68
+ "eq_tensor_rank3_randn" : op_eq_tensor_rank3_randn ,
69
+ "eq_tensor_rank4_randn" : op_eq_tensor_rank4_randn ,
64
70
}
65
71
72
+ test_data_scalar = {
73
+ "eq_scalar_rank1_ones" : op_eq_scalar_rank1_ones ,
74
+ "eq_scalar_rank2_rand" : op_eq_scalar_rank2_rand ,
75
+ "eq_scalar_rank3_randn" : op_eq_scalar_rank3_randn ,
76
+ "eq_scalar_rank4_randn" : op_eq_scalar_rank4_randn ,
77
+ }
78
+
79
+
80
+ @common .parametrize ("test_module" , test_data_tensor )
81
+ def test_eq_tensor_tosa_MI (test_module ):
82
+ pipeline = TosaPipelineMI [input_t ](
83
+ test_module , test_module .get_inputs (), Equal .aten_op_BI , Equal .exir_op
84
+ )
85
+ pipeline .run ()
66
86
67
- @common .parametrize ("test_module" , test_data_common )
68
- def test_eq_tosa_MI (test_module ):
87
+
88
+ @common .parametrize ("test_module" , test_data_scalar )
89
+ def test_eq_scalar_tosa_MI (test_module ):
69
90
pipeline = TosaPipelineMI [input_t ](
70
- test_module , test_module .get_inputs (), aten_op , exir_op
91
+ test_module ,
92
+ test_module .get_inputs (),
93
+ Equal .aten_op_MI ,
94
+ Equal .exir_op ,
71
95
)
72
96
pipeline .run ()
73
97
74
98
75
- @common .parametrize ("test_module" , test_data_common )
99
+ @common .parametrize ("test_module" , test_data_tensor | test_data_scalar )
76
100
def test_eq_tosa_BI (test_module ):
77
101
pipeline = TosaPipelineBI [input_t ](
78
- test_module , test_module .get_inputs (), aten_op , exir_op
102
+ test_module , test_module .get_inputs (), Equal . aten_op_BI , Equal . exir_op
79
103
)
80
104
pipeline .run ()
81
105
82
106
83
- @common .parametrize ("test_module" , test_data_common )
84
- def test_eq_u55_BI (test_module ):
107
+ @common .parametrize ("test_module" , test_data_tensor )
108
+ @common .XfailIfNoCorstone300
109
+ def test_eq_tensor_u55_BI (test_module ):
85
110
# EQUAL is not supported on U55.
86
111
pipeline = OpNotSupportedPipeline [input_t ](
87
112
test_module ,
88
113
test_module .get_inputs (),
89
114
"TOSA-0.80+BI+u55" ,
90
- {exir_op : 1 },
91
- )
92
- pipeline .run ()
93
-
94
-
95
- @common .parametrize ("test_module" , test_data_common )
96
- def test_eq_u85_BI (test_module ):
97
- pipeline = EthosU85PipelineBI [input_t ](
98
- test_module ,
99
- test_module .get_inputs (),
100
- aten_op ,
101
- exir_op ,
102
- run_on_fvp = False ,
103
- use_to_edge_transform_and_lower = True ,
115
+ {Equal .exir_op : 1 },
104
116
)
105
117
pipeline .run ()
106
118
107
119
108
- @common .parametrize ("test_module" , test_data_common )
109
- @pytest . mark . skip ( reason = "The same as test_eq_u55_BI" )
110
- def test_eq_u55_BI_on_fvp (test_module ):
120
+ @common .parametrize ("test_module" , test_data_scalar )
121
+ @common . XfailIfNoCorstone300
122
+ def test_eq_scalar_u55_BI (test_module ):
111
123
# EQUAL is not supported on U55.
112
124
pipeline = OpNotSupportedPipeline [input_t ](
113
125
test_module ,
114
126
test_module .get_inputs (),
115
127
"TOSA-0.80+BI+u55" ,
116
- {exir_op : 1 },
128
+ {Equal .exir_op : 1 },
129
+ n_expected_delegates = 1 ,
117
130
)
118
131
pipeline .run ()
119
132
120
133
121
134
@common .parametrize (
122
135
"test_module" ,
123
- test_data_common ,
124
- xfails = {"eq_rank4_randn" : "4D fails because boolean Tensors can't be subtracted" },
136
+ test_data_tensor | test_data_scalar ,
137
+ xfails = {
138
+ "eq_tensor_rank4_randn" : "4D fails because boolean Tensors can't be subtracted" ,
139
+ },
125
140
)
126
- @common .SkipIfNoCorstone320
127
- def test_eq_u85_BI_on_fvp (test_module ):
141
+ @common .XfailIfNoCorstone320
142
+ def test_eq_u85_BI (test_module ):
128
143
pipeline = EthosU85PipelineBI [input_t ](
129
144
test_module ,
130
145
test_module .get_inputs (),
131
- aten_op ,
132
- exir_op ,
146
+ Equal . aten_op_BI ,
147
+ Equal . exir_op ,
133
148
run_on_fvp = True ,
134
- use_to_edge_transform_and_lower = True ,
135
149
)
136
150
pipeline .run ()
0 commit comments