
Commit b2e1162

Support non-default dim orders in pybindings + aten bridge (#9736)
Summary: Now that we have better dim order support, we can pass inputs into models in non-default dim orders, such as channels last. This already works in the runtime, but the Python layer currently asserts that input tensors are in the default dim order. This PR relaxes that restriction to allow channels-last input tensors. Reviewed By: kimishpatel. Differential Revision: D71716100
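Illustrative usage (not part of the diff): with this change, a channels-last tensor can be passed straight to a loaded module through the pybindings layer. The model path and input shape below are placeholders for a program whose method takes a 4D image-like input.

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Hypothetical exported program path.
module = _load_for_executorch("model.pte")

# Previously the Python layer rejected non-contiguous inputs; a channels-last
# tensor is now mapped to dim order (0, 2, 3, 1) and passed through.
x = torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last)
outputs = module(x)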
1 parent 69cc7fa commit b2e1162

File tree (4 files changed: +221, -19 lines)

extension/aten_util/aten_bridge.cpp (+19, -8)
extension/aten_util/test/aten_bridge_test.cpp (+100)
extension/pybindings/pybindings.cpp (+14, -10)
extension/pybindings/test/make_test.py (+88, -1)

extension/aten_util/aten_bridge.cpp (+19, -8)
@@ -8,6 +8,8 @@
 
 #include <executorch/extension/aten_util/aten_bridge.h>
 
+#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_dimension_limit.h>
 #include <executorch/runtime/platform/assert.h>
 #include <cstring>
 
@@ -40,7 +42,12 @@ ET_CHECK_MSG(
         ssize_t(a.size(i)),
         ssize_t(b.size(i)));
   }
-  // check strides
+  // check strides and dim order
+  std::array<exec_aten::StridesType, executorch::runtime::kTensorDimensionLimit>
+      expected_strides{};
+  runtime::dim_order_to_stride_nocheck(
+      b.sizes().data(), b.dim_order().data(), b.dim(), expected_strides.data());
+
   for (size_t i = 0, dims = a.dim(); i < dims; ++i) {
     // Dont match strides if the size is 1.
     // Why? Because tensor is non-contig only if
@@ -52,6 +59,12 @@ ET_CHECK_MSG(
         i,
         ssize_t(a.stride(i)),
         ssize_t(b.strides()[i]));
+    ET_CHECK_MSG(
+        (b.size(i) == 1 || (b.strides()[i] == expected_strides[i])),
+        "Strides don't match dim order at index:%zd, stride: %zd != expected %zd",
+        i,
+        ssize_t(a.stride(i)),
+        ssize_t(expected_strides[i]));
   }
   // check dtype
   ET_CHECK_MSG(
@@ -109,13 +122,11 @@ c10::ScalarType executorch_to_torch_scalar_type(
 void alias_etensor_to_attensor(
     at::Tensor& aten_tensor,
     torch::executor::Tensor& mutable_et) {
-  // TODO(kimishpatel): contiguous according to memformat
-  // Right now we assume everything is channels first contiguous
-  // Note that input tensor must be contiguous for us to alias.
-  // Mixing aliasing and copying is dangerous since if we aliased
-  // the instance of mutatble_et to aten_tensor in the previous call,
-  // then in the next call copying will not be the correct behavior.
-  ET_CHECK_MSG(aten_tensor.is_contiguous(), "Input tensor must be contiguous");
+  ET_CHECK_MSG(
+      aten_tensor.is_contiguous() ||
+          aten_tensor.is_contiguous(at::MemoryFormat::ChannelsLast),
+      "Input tensor must have contiguous or channels last memory format");
+
   check_tensor_meta(aten_tensor, mutable_et);
   mutable_et.unsafeGetTensorImpl()->set_data(aten_tensor.mutable_data_ptr());
 }
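Note (illustrative, not part of the commit): the new check compares each ETensor stride against the strides implied by its dim order. A minimal Python sketch of what dim_order_to_stride_nocheck computes:

def strides_from_dim_order(sizes, dim_order):
    # The last entry of dim_order is the innermost dimension in memory and
    # gets stride 1; each earlier entry's stride is the product of the sizes
    # of the dimensions that follow it in memory.
    strides = [0] * len(sizes)
    acc = 1
    for d in reversed(dim_order):
        strides[d] = acc
        acc *= sizes[d]
    return strides

# Channels-last dim order (0, 2, 3, 1) on a (2, 3, 4, 5) tensor reproduces the
# strides ATen reports for MemoryFormat::ChannelsLast.
assert strides_from_dim_order([2, 3, 4, 5], [0, 2, 3, 1]) == [60, 1, 15, 3]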

extension/aten_util/test/aten_bridge_test.cpp (+100)
@@ -154,3 +154,103 @@ TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) {
   alias_etensor_to_attensor(at_tensor, *et_tensor_ptr);
   EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr());
 }
+
+TEST(ATenBridgeTest, AliasATTensorToETensorChannelsLast) {
+  auto at_tensor = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast);
+  std::vector<Tensor::SizesType> sizes(
+      at_tensor.sizes().begin(), at_tensor.sizes().end());
+  std::vector<Tensor::DimOrderType> dim_order = {0, 2, 3, 1};
+  std::vector<Tensor::StridesType> strides(
+      at_tensor.strides().begin(), at_tensor.strides().end());
+  auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype());
+  std::vector<uint8_t> etensor_data(at_tensor.nbytes());
+  torch::executor::TensorImpl tensor_impl(
+      dtype,
+      at_tensor.dim(),
+      sizes.data(),
+      etensor_data.data(),
+      dim_order.data(),
+      strides.data());
+  torch::executor::Tensor etensor(&tensor_impl);
+  auto aliased_at_tensor = alias_attensor_to_etensor(etensor);
+  EXPECT_EQ(aliased_at_tensor.const_data_ptr(), etensor_data.data());
+}
+
+TEST(ATenBridgeTest, AliasATTensorToETensorFailDimOrder) {
+  auto at_tensor = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast);
+  std::vector<Tensor::SizesType> sizes(
+      at_tensor.sizes().begin(), at_tensor.sizes().end());
+  std::vector<Tensor::DimOrderType> dim_order = {0, 1, 2, 3};
+  std::vector<Tensor::StridesType> strides(
+      at_tensor.strides().begin(), at_tensor.strides().end());
+  auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype());
+  std::vector<uint8_t> etensor_data(at_tensor.nbytes());
+  torch::executor::TensorImpl tensor_impl(
+      dtype,
+      at_tensor.dim(),
+      sizes.data(),
+      etensor_data.data(),
+      dim_order.data(),
+      strides.data());
+  torch::executor::Tensor etensor(&tensor_impl);
+  ET_EXPECT_DEATH(alias_attensor_to_etensor(etensor), "");
+}
+
+TEST(ATenBridgeTest, AliasETensorToATenTensorChannelsLast) {
+  auto at_tensor = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast);
+  std::vector<Tensor::SizesType> sizes(
+      at_tensor.sizes().begin(), at_tensor.sizes().end());
+  std::vector<Tensor::DimOrderType> dim_order = {0, 2, 3, 1};
+  std::vector<Tensor::StridesType> strides(
+      at_tensor.strides().begin(), at_tensor.strides().end());
+  auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype());
+  torch::executor::TensorImpl tensor_impl(
+      dtype,
+      at_tensor.dim(),
+      sizes.data(),
+      nullptr,
+      dim_order.data(),
+      strides.data());
+  torch::executor::Tensor etensor(&tensor_impl);
+  alias_etensor_to_attensor(at_tensor, etensor);
+  EXPECT_EQ(at_tensor.const_data_ptr(), etensor.const_data_ptr());
+}
+
+TEST(ATenBridgeTest, AliasETensorToATenTensorFailDimOrder) {
+  auto at_tensor = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast);
+  std::vector<Tensor::SizesType> sizes(
+      at_tensor.sizes().begin(), at_tensor.sizes().end());
+  std::vector<Tensor::DimOrderType> dim_order = {0, 1, 2, 3};
+  std::vector<Tensor::StridesType> strides(
+      at_tensor.strides().begin(), at_tensor.strides().end());
+  auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype());
+  torch::executor::TensorImpl tensor_impl(
+      dtype,
+      at_tensor.dim(),
+      sizes.data(),
+      nullptr,
+      dim_order.data(),
+      strides.data());
+  torch::executor::Tensor etensor(&tensor_impl);
+  ET_EXPECT_DEATH(alias_etensor_to_attensor(at_tensor, etensor), "");
+}
+
+TEST(ATenBridgeTest, AliasETensorToATenTensorFailUnsupportedDimOrder) {
+  auto at_tensor =
+      at::randn({1, 2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast3d);
+  std::vector<Tensor::SizesType> sizes(
+      at_tensor.sizes().begin(), at_tensor.sizes().end());
+  std::vector<Tensor::DimOrderType> dim_order = {0, 2, 3, 4, 1};
+  std::vector<Tensor::StridesType> strides(
+      at_tensor.strides().begin(), at_tensor.strides().end());
+  auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype());
+  torch::executor::TensorImpl tensor_impl(
+      dtype,
+      at_tensor.dim(),
+      sizes.data(),
+      nullptr,
+      dim_order.data(),
+      strides.data());
+  torch::executor::Tensor etensor(&tensor_impl);
+  ET_EXPECT_DEATH(alias_etensor_to_attensor(at_tensor, etensor), "");
+}

extension/pybindings/pybindings.cpp (+14, -10)
@@ -703,13 +703,6 @@ struct PyModule final {
       const std::string& type_str = py::str(python_input.get_type());
       if (type_str == "<class 'torch.Tensor'>") {
         auto at_tensor = python_input.cast<at::Tensor>();
-        // alias_etensor_to_attensor will assert on this later, so to better
-        // propogate up to python we check early and throw an exception.
-        if (!at_tensor.is_contiguous()) {
-          auto error_msg = "Input " + std::to_string(i) + "for method " +
-              method_name + " is not contiguous.";
-          throw std::runtime_error(error_msg);
-        }
 
 #ifdef USE_ATEN_LIB
         EValue evalue(at_tensor);
@@ -725,10 +718,21 @@
         input_strides.emplace_back(
             at_tensor.strides().begin(), at_tensor.strides().end());
 
-        // Only works for MemoryFormat::Contiguous inputs
+        // Only works for MemoryFormat::Contiguous or MemoryFormat::ChannelsLast
+        // inputs
         std::vector<torch::executor::Tensor::DimOrderType> dim_order;
-        for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
-          dim_order.push_back(cur_dim);
+        if (at_tensor.is_contiguous()) {
+          for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
+            dim_order.push_back(cur_dim);
+          }
+        } else if (
+            at_tensor.is_contiguous(at::MemoryFormat::ChannelsLast) &&
+            at_tensor.dim() == 4) {
+          dim_order = decltype(dim_order)({0, 2, 3, 1});
+        } else {
+          auto error_msg = "Input " + std::to_string(i) + "for method " +
+              method_name + " should be contiguous or channels-last.";
+          throw std::runtime_error(error_msg);
         }
         input_dim_order.push_back(std::move(dim_order));
         input_tensors.emplace_back(
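Note (illustrative): a rough Python analogue of the dim-order selection added above, which is not the actual implementation but mirrors its branches: identity order for contiguous inputs, (0, 2, 3, 1) for 4D channels-last inputs, and an error for anything else.

import torch

def infer_dim_order(t: torch.Tensor) -> list:
    # Default (contiguous) layout maps to the identity dim order.
    if t.is_contiguous():
        return list(range(t.dim()))
    # Only 4D channels-last tensors map to the NHWC dim order.
    if t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last):
        return [0, 2, 3, 1]
    raise RuntimeError("input should be contiguous or channels-last")

assert infer_dim_order(torch.randn(2, 2)) == [0, 1]
x = torch.randn(1, 2, 3, 4).to(memory_format=torch.channels_last)
assert infer_dim_order(x) == [0, 2, 3, 1]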

extension/pybindings/test/make_test.py (+88, -1)
@@ -32,6 +32,40 @@ def get_inputs(self):
         return (torch.ones(2, 2), torch.ones(2, 2))
 
 
+class ModuleChannelsLast(torch.nn.Module):
+    """The module to serialize and execute."""
+
+    def forward(self, x):
+        return torch.nn.functional.interpolate(
+            x,
+            scale_factor=2,
+            mode="nearest",
+        )
+
+    def get_methods_to_export(self):
+        return ("forward",)
+
+    def get_inputs(self):
+        return (torch.ones(1, 2, 3, 4).to(memory_format=torch.channels_last),)
+
+
+class ModuleChannelsLastInDefaultOut(torch.nn.Module):
+    """The module to serialize and execute."""
+
+    def forward(self, x):
+        return torch.nn.functional.interpolate(
+            x,
+            scale_factor=2,
+            mode="nearest",
+        ).to(memory_format=torch.contiguous_format)
+
+    def get_methods_to_export(self):
+        return ("forward",)
+
+    def get_inputs(self):
+        return (torch.ones(1, 2, 3, 4).to(memory_format=torch.channels_last),)
+
+
 class ModuleMulti(torch.nn.Module):
     """The module to serialize and execute."""
 
@@ -298,11 +332,61 @@ def test_constant_output_not_memory_planned(tester):
         # The test module adds the input to torch.ones(2,2), so its output should be the same
         # as adding them directly.
         expected = torch.ones(2, 2) + torch.ones(2, 2)
-        tester.assertEqual(str(expected), str(executorch_output[0]))
+        tester.assertTrue(torch.allclose(expected, executorch_output[0]))
 
         # The test module returns the state. Check that its value is correct.
         tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))
 
+    def test_channels_last(tester) -> None:
+        # Create an ExecuTorch program from ModuleChannelsLast.
+        model = ModuleChannelsLast()
+        exported_program, inputs = create_program(model)
+
+        # Use pybindings to load and execute the program.
+        executorch_module = load_fn(exported_program.buffer)
+        # Inovke the callable on executorch_module instead of calling module.forward.
+        # Use only one input to test this case.
+        executorch_output = executorch_module(inputs[0])[0]
+
+        # The test module adds the two inputs, so its output should be the same
+        # as adding them directly.
+        expected = model(inputs[0])
+        tester.assertTrue(torch.allclose(expected, executorch_output))
+
+    def test_unsupported_dim_order(tester) -> None:
+        """
+        Verify that the pybind layer rejects unsupported dim orders.
+        """
+
+        # Create an ExecuTorch program from ModuleChannelsLast.
+        model = ModuleChannelsLast()
+        exported_program, inputs = create_program(model)
+        inputs = (
+            torch.randn(1, 2, 3, 4, 5).to(memory_format=torch.channels_last_3d),
+        )
+
+        # Use pybindings to load and execute the program.
+        executorch_module = load_fn(exported_program.buffer)
+
+        # We expect execution to error because of the invalid input dim order.
+        tester.assertRaises(RuntimeError, executorch_module, inputs[0])
+
+    def test_channels_last_in_default_out(tester) -> None:
+        # Create an ExecuTorch program from ModuleChannelsLastInDefaultOut.
+        model = ModuleChannelsLastInDefaultOut()
+        exported_program, inputs = create_program(model)
+
+        # Use pybindings to load and execute the program.
+        executorch_module = load_fn(exported_program.buffer)
+        # Inovke the callable on executorch_module instead of calling module.forward.
+        # Use only one input to test this case.
+        executorch_output = executorch_module(inputs[0])[0]
+
+        # The test module adds the two inputs, so its output should be the same
+        # as adding them directly.
+        expected = model(inputs[0])
+        tester.assertTrue(torch.allclose(expected, executorch_output))
+
     def test_method_meta(tester) -> None:
         exported_program, inputs = create_program(ModuleAdd())
 
@@ -388,6 +472,9 @@ def test_verification_config(tester) -> None:
     test_module_single_input(tester)
     test_stderr_redirect(tester)
    test_quantized_ops(tester)
+    test_channels_last(tester)
+    test_channels_last_in_default_out(tester)
+    test_unsupported_dim_order(tester)
     test_constant_output_not_memory_planned(tester)
     test_method_meta(tester)
     test_bad_name(tester)
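Note (illustrative, not part of the commit): the two new test modules differ only in output memory format. In eager mode, nearest-neighbor interpolation on a channels-last input is expected to keep the channels-last layout, while ModuleChannelsLastInDefaultOut converts its result back to the default contiguous format; a quick sketch of that assumption:

import torch

x = torch.ones(1, 2, 3, 4).to(memory_format=torch.channels_last)
y = torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")

# Expected eager behavior: the interpolated result keeps the channels-last
# layout of its input ...
assert y.is_contiguous(memory_format=torch.channels_last)
# ... unless it is explicitly converted back, as ModuleChannelsLastInDefaultOut does.
assert y.to(memory_format=torch.contiguous_format).is_contiguous()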
