|
1 | 1 | # Copyright 2024-2025 Arm Limited and/or its affiliates.
|
2 |
| -# All rights reserved. |
3 | 2 | #
|
4 | 3 | # This source code is licensed under the BSD-style license found in the
|
5 | 4 | # LICENSE file in the root directory of this source tree.
|
@@ -37,14 +36,14 @@ class Expand(torch.nn.Module):
|
37 | 36 | # (input tensor, multiples)
|
38 | 37 | test_parameters = [
|
39 | 38 | (torch.rand(1), (2,)),
|
40 |
| - (torch.randn(1, 4), (1, -1)), |
41 | 39 | (torch.randn(1), (2, 2, 4)),
|
42 | 40 | (torch.randn(1, 1, 1, 5), (1, 4, -1, -1)),
|
43 |
| - (torch.randn(1, 1, 192), (1, -1, -1)), |
44 | 41 | (torch.randn(1, 1), (1, 2, 2, 4)),
|
45 | 42 | (torch.randn(1, 1), (2, 2, 2, 4)),
|
46 | 43 | (torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
|
47 | 44 | (torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
|
| 45 | + (torch.randn(1, 4), (1, -1)), |
| 46 | + (torch.randn(1, 1, 192), (1, -1, -1)), |
48 | 47 | ]
|
49 | 48 |
|
50 | 49 | def forward(self, x: torch.Tensor, m: Sequence):
|
@@ -117,34 +116,52 @@ def test_expand_tosa_MI(self, test_input, multiples):
|
117 | 116 | def test_expand_tosa_BI(self, test_input, multiples):
|
118 | 117 | self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))
|
119 | 118 |
|
120 |
| - @parameterized.expand(Expand.test_parameters[:-3]) |
| 119 | + @parameterized.expand(Expand.test_parameters[:-5]) |
121 | 120 | @pytest.mark.corstone_fvp
|
122 | 121 | def test_expand_u55_BI(self, test_input, multiples):
|
123 | 122 | self._test_expand_ethosu_BI_pipeline(
|
124 | 123 | common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
|
125 | 124 | )
|
126 | 125 |
|
127 | 126 | # MLETORCH-629: Expand does not work on FVP with batch>1
|
128 |
| - @parameterized.expand(Expand.test_parameters[-3:]) |
| 127 | + @parameterized.expand(Expand.test_parameters[-5:-2]) |
129 | 128 | @pytest.mark.corstone_fvp
|
130 | 129 | @conftest.expectedFailureOnFVP
|
| 130 | + def test_expand_u55_BI_xfails_on_fvp(self, test_input, multiples): |
| 131 | + self._test_expand_ethosu_BI_pipeline( |
| 132 | + common.get_u55_compile_spec(), self.Expand(), (test_input, multiples) |
| 133 | + ) |
| 134 | + |
| 135 | + @parameterized.expand(Expand.test_parameters[-2:]) |
| 136 | + @pytest.mark.xfail( |
| 137 | + reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs" |
| 138 | + ) |
131 | 139 | def test_expand_u55_BI_xfails(self, test_input, multiples):
|
132 | 140 | self._test_expand_ethosu_BI_pipeline(
|
133 | 141 | common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
|
134 | 142 | )
|
135 | 143 |
|
136 |
| - @parameterized.expand(Expand.test_parameters[:-3]) |
| 144 | + @parameterized.expand(Expand.test_parameters[:-5]) |
137 | 145 | @pytest.mark.corstone_fvp
|
138 | 146 | def test_expand_u85_BI(self, test_input, multiples):
|
139 | 147 | self._test_expand_ethosu_BI_pipeline(
|
140 | 148 | common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
|
141 | 149 | )
|
142 | 150 |
|
143 | 151 | # MLETORCH-629: Expand does not work on FVP with batch>1
|
144 |
| - @parameterized.expand(Expand.test_parameters[-3:]) |
| 152 | + @parameterized.expand(Expand.test_parameters[-5:-2]) |
145 | 153 | @pytest.mark.corstone_fvp
|
146 | 154 | @conftest.expectedFailureOnFVP
|
147 |
| - def test_expand_u85_BI_xfails(self, test_input, multiples): |
| 155 | + def test_expand_u85_BI_xfails_on_fvp(self, test_input, multiples): |
| 156 | + self._test_expand_ethosu_BI_pipeline( |
| 157 | + common.get_u85_compile_spec(), self.Expand(), (test_input, multiples) |
| 158 | + ) |
| 159 | + |
| 160 | + @parameterized.expand(Expand.test_parameters[-2:]) |
| 161 | + @pytest.mark.xfail( |
| 162 | + reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs" |
| 163 | + ) |
| 164 | + def test_expand_u85_xfails(self, test_input, multiples): |
148 | 165 | self._test_expand_ethosu_BI_pipeline(
|
149 | 166 | common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
|
150 | 167 | )
|
0 commit comments