@@ -154,3 +154,103 @@ TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) {
154
154
alias_etensor_to_attensor (at_tensor, *et_tensor_ptr);
155
155
EXPECT_EQ (at_tensor.const_data_ptr (), et_tensor_ptr->const_data_ptr ());
156
156
}
157
+
158
+ TEST (ATenBridgeTest, AliasATTensorToETensorChannelsLast) {
159
+ auto at_tensor = at::randn ({2 , 3 , 4 , 5 }).to (at::MemoryFormat::ChannelsLast);
160
+ std::vector<Tensor::SizesType> sizes (
161
+ at_tensor.sizes ().begin (), at_tensor.sizes ().end ());
162
+ std::vector<Tensor::DimOrderType> dim_order = {0 , 2 , 3 , 1 };
163
+ std::vector<Tensor::StridesType> strides (
164
+ at_tensor.strides ().begin (), at_tensor.strides ().end ());
165
+ auto dtype = torchToExecuTorchScalarType (at_tensor.options ().dtype ());
166
+ std::vector<uint8_t > etensor_data (at_tensor.nbytes ());
167
+ torch::executor::TensorImpl tensor_impl (
168
+ dtype,
169
+ at_tensor.dim (),
170
+ sizes.data (),
171
+ etensor_data.data (),
172
+ dim_order.data (),
173
+ strides.data ());
174
+ torch::executor::Tensor etensor (&tensor_impl);
175
+ auto aliased_at_tensor = alias_attensor_to_etensor (etensor);
176
+ EXPECT_EQ (aliased_at_tensor.const_data_ptr (), etensor_data.data ());
177
+ }
178
+
179
+ TEST (ATenBridgeTest, AliasATTensorToETensorFailDimOrder) {
180
+ auto at_tensor = at::randn ({2 , 3 , 4 , 5 }).to (at::MemoryFormat::ChannelsLast);
181
+ std::vector<Tensor::SizesType> sizes (
182
+ at_tensor.sizes ().begin (), at_tensor.sizes ().end ());
183
+ std::vector<Tensor::DimOrderType> dim_order = {0 , 1 , 2 , 3 };
184
+ std::vector<Tensor::StridesType> strides (
185
+ at_tensor.strides ().begin (), at_tensor.strides ().end ());
186
+ auto dtype = torchToExecuTorchScalarType (at_tensor.options ().dtype ());
187
+ std::vector<uint8_t > etensor_data (at_tensor.nbytes ());
188
+ torch::executor::TensorImpl tensor_impl (
189
+ dtype,
190
+ at_tensor.dim (),
191
+ sizes.data (),
192
+ etensor_data.data (),
193
+ dim_order.data (),
194
+ strides.data ());
195
+ torch::executor::Tensor etensor (&tensor_impl);
196
+ ET_EXPECT_DEATH (alias_attensor_to_etensor (etensor), " " );
197
+ }
198
+
199
+ TEST (ATenBridgeTest, AliasETensorToATenTensorChannelsLast) {
200
+ auto at_tensor = at::randn ({2 , 3 , 4 , 5 }).to (at::MemoryFormat::ChannelsLast);
201
+ std::vector<Tensor::SizesType> sizes (
202
+ at_tensor.sizes ().begin (), at_tensor.sizes ().end ());
203
+ std::vector<Tensor::DimOrderType> dim_order = {0 , 2 , 3 , 1 };
204
+ std::vector<Tensor::StridesType> strides (
205
+ at_tensor.strides ().begin (), at_tensor.strides ().end ());
206
+ auto dtype = torchToExecuTorchScalarType (at_tensor.options ().dtype ());
207
+ torch::executor::TensorImpl tensor_impl (
208
+ dtype,
209
+ at_tensor.dim (),
210
+ sizes.data (),
211
+ nullptr ,
212
+ dim_order.data (),
213
+ strides.data ());
214
+ torch::executor::Tensor etensor (&tensor_impl);
215
+ alias_etensor_to_attensor (at_tensor, etensor);
216
+ EXPECT_EQ (at_tensor.const_data_ptr (), etensor.const_data_ptr ());
217
+ }
218
+
219
+ TEST (ATenBridgeTest, AliasETensorToATenTensorFailDimOrder) {
220
+ auto at_tensor = at::randn ({2 , 3 , 4 , 5 }).to (at::MemoryFormat::ChannelsLast);
221
+ std::vector<Tensor::SizesType> sizes (
222
+ at_tensor.sizes ().begin (), at_tensor.sizes ().end ());
223
+ std::vector<Tensor::DimOrderType> dim_order = {0 , 1 , 2 , 3 };
224
+ std::vector<Tensor::StridesType> strides (
225
+ at_tensor.strides ().begin (), at_tensor.strides ().end ());
226
+ auto dtype = torchToExecuTorchScalarType (at_tensor.options ().dtype ());
227
+ torch::executor::TensorImpl tensor_impl (
228
+ dtype,
229
+ at_tensor.dim (),
230
+ sizes.data (),
231
+ nullptr ,
232
+ dim_order.data (),
233
+ strides.data ());
234
+ torch::executor::Tensor etensor (&tensor_impl);
235
+ ET_EXPECT_DEATH (alias_etensor_to_attensor (at_tensor, etensor), " " );
236
+ }
237
+
238
+ TEST (ATenBridgeTest, AliasETensorToATenTensorFailUnsupportedDimOrder) {
239
+ auto at_tensor =
240
+ at::randn ({1 , 2 , 3 , 4 , 5 }).to (at::MemoryFormat::ChannelsLast3d);
241
+ std::vector<Tensor::SizesType> sizes (
242
+ at_tensor.sizes ().begin (), at_tensor.sizes ().end ());
243
+ std::vector<Tensor::DimOrderType> dim_order = {0 , 2 , 3 , 4 , 1 };
244
+ std::vector<Tensor::StridesType> strides (
245
+ at_tensor.strides ().begin (), at_tensor.strides ().end ());
246
+ auto dtype = torchToExecuTorchScalarType (at_tensor.options ().dtype ());
247
+ torch::executor::TensorImpl tensor_impl (
248
+ dtype,
249
+ at_tensor.dim (),
250
+ sizes.data (),
251
+ nullptr ,
252
+ dim_order.data (),
253
+ strides.data ());
254
+ torch::executor::Tensor etensor (&tensor_impl);
255
+ ET_EXPECT_DEATH (alias_etensor_to_attensor (at_tensor, etensor), " " );
256
+ }
0 commit comments