Skip to content

Support non-default dim orders in pybindings + aten bridge #9736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions extension/aten_util/aten_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include <executorch/extension/aten_util/aten_bridge.h>

#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_dimension_limit.h>
#include <executorch/runtime/platform/assert.h>
#include <cstring>

Expand Down Expand Up @@ -40,7 +42,12 @@ ET_CHECK_MSG(
ssize_t(a.size(i)),
ssize_t(b.size(i)));
}
// check strides
// check strides and dim order
std::array<exec_aten::StridesType, executorch::runtime::kTensorDimensionLimit>
expected_strides{};
runtime::dim_order_to_stride_nocheck(
b.sizes().data(), b.dim_order().data(), b.dim(), expected_strides.data());

for (size_t i = 0, dims = a.dim(); i < dims; ++i) {
// Dont match strides if the size is 1.
// Why? Because tensor is non-contig only if
Expand All @@ -52,6 +59,12 @@ ET_CHECK_MSG(
i,
ssize_t(a.stride(i)),
ssize_t(b.strides()[i]));
ET_CHECK_MSG(
(b.size(i) == 1 || (b.strides()[i] == expected_strides[i])),
"Strides don't match dim order at index:%zd, stride: %zd != expected %zd",
i,
ssize_t(a.stride(i)),
ssize_t(expected_strides[i]));
}
// check dtype
ET_CHECK_MSG(
Expand Down Expand Up @@ -109,13 +122,11 @@ c10::ScalarType executorch_to_torch_scalar_type(
void alias_etensor_to_attensor(
at::Tensor& aten_tensor,
torch::executor::Tensor& mutable_et) {
// TODO(kimishpatel): contiguous according to memformat
// Right now we assume everything is channels first contiguous
// Note that input tensor must be contiguous for us to alias.
// Mixing aliasing and copying is dangerous since if we aliased
// the instance of mutatble_et to aten_tensor in the previous call,
// then in the next call copying will not be the correct behavior.
ET_CHECK_MSG(aten_tensor.is_contiguous(), "Input tensor must be contiguous");
ET_CHECK_MSG(
aten_tensor.is_contiguous() ||
aten_tensor.is_contiguous(at::MemoryFormat::ChannelsLast),
"Input tensor must have contiguous or channels last memory format");

check_tensor_meta(aten_tensor, mutable_et);
mutable_et.unsafeGetTensorImpl()->set_data(aten_tensor.mutable_data_ptr());
}
Expand Down
100 changes: 100 additions & 0 deletions extension/aten_util/test/aten_bridge_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,103 @@ TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) {
alias_etensor_to_attensor(at_tensor, *et_tensor_ptr);
EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr());
}

TEST(ATenBridgeTest, AliasATTensorToETensorChannelsLast) {
  // Source ATen tensor in channels-last (NHWC) memory format.
  auto src = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast);

  // Mirror the ATen metadata into ExecuTorch-typed containers.
  std::vector<Tensor::SizesType> et_sizes(
      src.sizes().begin(), src.sizes().end());
  std::vector<Tensor::StridesType> et_strides(
      src.strides().begin(), src.strides().end());
  // Dim order corresponding to MemoryFormat::ChannelsLast on a 4-D tensor.
  std::vector<Tensor::DimOrderType> et_dim_order = {0, 2, 3, 1};
  auto et_dtype = torchToExecuTorchScalarType(src.options().dtype());

  // Backing buffer owned by the ETensor, distinct from the ATen storage.
  std::vector<uint8_t> buffer(src.nbytes());
  torch::executor::TensorImpl impl(
      et_dtype,
      src.dim(),
      et_sizes.data(),
      buffer.data(),
      et_dim_order.data(),
      et_strides.data());
  torch::executor::Tensor etensor(&impl);

  // The aliased ATen tensor must share the ETensor's storage.
  auto aliased = alias_attensor_to_etensor(etensor);
  EXPECT_EQ(aliased.const_data_ptr(), buffer.data());
}

TEST(ATenBridgeTest, AliasATTensorToETensorFailDimOrder) {
  // Source ATen tensor in channels-last (NHWC) memory format.
  auto src = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast);

  std::vector<Tensor::SizesType> et_sizes(
      src.sizes().begin(), src.sizes().end());
  std::vector<Tensor::StridesType> et_strides(
      src.strides().begin(), src.strides().end());
  // Deliberately wrong: contiguous dim order paired with channels-last
  // strides, which the bridge's metadata check must reject.
  std::vector<Tensor::DimOrderType> et_dim_order = {0, 1, 2, 3};
  auto et_dtype = torchToExecuTorchScalarType(src.options().dtype());

  std::vector<uint8_t> buffer(src.nbytes());
  torch::executor::TensorImpl impl(
      et_dtype,
      src.dim(),
      et_sizes.data(),
      buffer.data(),
      et_dim_order.data(),
      et_strides.data());
  torch::executor::Tensor etensor(&impl);

  // Mismatched dim order / strides must abort.
  ET_EXPECT_DEATH(alias_attensor_to_etensor(etensor), "");
}

TEST(ATenBridgeTest, AliasETensorToATenTensorChannelsLast) {
  // Source ATen tensor in channels-last (NHWC) memory format.
  auto src = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast);

  std::vector<Tensor::SizesType> et_sizes(
      src.sizes().begin(), src.sizes().end());
  std::vector<Tensor::StridesType> et_strides(
      src.strides().begin(), src.strides().end());
  // Dim order corresponding to MemoryFormat::ChannelsLast on a 4-D tensor.
  std::vector<Tensor::DimOrderType> et_dim_order = {0, 2, 3, 1};
  auto et_dtype = torchToExecuTorchScalarType(src.options().dtype());

  // No backing buffer: aliasing will point the ETensor at the ATen storage.
  torch::executor::TensorImpl impl(
      et_dtype,
      src.dim(),
      et_sizes.data(),
      nullptr,
      et_dim_order.data(),
      et_strides.data());
  torch::executor::Tensor etensor(&impl);

  alias_etensor_to_attensor(src, etensor);
  // After aliasing, both tensors must share the same storage.
  EXPECT_EQ(src.const_data_ptr(), etensor.const_data_ptr());
}

TEST(ATenBridgeTest, AliasETensorToATenTensorFailDimOrder) {
  // Source ATen tensor in channels-last (NHWC) memory format.
  auto src = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast);

  std::vector<Tensor::SizesType> et_sizes(
      src.sizes().begin(), src.sizes().end());
  std::vector<Tensor::StridesType> et_strides(
      src.strides().begin(), src.strides().end());
  // Deliberately wrong: contiguous dim order paired with channels-last
  // strides, which the bridge's metadata check must reject.
  std::vector<Tensor::DimOrderType> et_dim_order = {0, 1, 2, 3};
  auto et_dtype = torchToExecuTorchScalarType(src.options().dtype());

  torch::executor::TensorImpl impl(
      et_dtype,
      src.dim(),
      et_sizes.data(),
      nullptr,
      et_dim_order.data(),
      et_strides.data());
  torch::executor::Tensor etensor(&impl);

  // Mismatched dim order / strides must abort.
  ET_EXPECT_DEATH(alias_etensor_to_attensor(src, etensor), "");
}

TEST(ATenBridgeTest, AliasETensorToATenTensorFailUnsupportedDimOrder) {
  // 5-D source tensor in channels-last-3d (NDHWC) memory format, which the
  // bridge does not support for aliasing.
  auto src =
      at::randn({1, 2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast3d);

  std::vector<Tensor::SizesType> et_sizes(
      src.sizes().begin(), src.sizes().end());
  std::vector<Tensor::StridesType> et_strides(
      src.strides().begin(), src.strides().end());
  // Dim order corresponding to MemoryFormat::ChannelsLast3d.
  std::vector<Tensor::DimOrderType> et_dim_order = {0, 2, 3, 4, 1};
  auto et_dtype = torchToExecuTorchScalarType(src.options().dtype());

  torch::executor::TensorImpl impl(
      et_dtype,
      src.dim(),
      et_sizes.data(),
      nullptr,
      et_dim_order.data(),
      et_strides.data());
  torch::executor::Tensor etensor(&impl);

  // Unsupported memory format must abort.
  ET_EXPECT_DEATH(alias_etensor_to_attensor(src, etensor), "");
}
24 changes: 14 additions & 10 deletions extension/pybindings/pybindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,13 +703,6 @@ struct PyModule final {
const std::string& type_str = py::str(python_input.get_type());
if (type_str == "<class 'torch.Tensor'>") {
auto at_tensor = python_input.cast<at::Tensor>();
// alias_etensor_to_attensor will assert on this later, so to better
// propogate up to python we check early and throw an exception.
if (!at_tensor.is_contiguous()) {
auto error_msg = "Input " + std::to_string(i) + "for method " +
method_name + " is not contiguous.";
throw std::runtime_error(error_msg);
}

#ifdef USE_ATEN_LIB
EValue evalue(at_tensor);
Expand All @@ -725,10 +718,21 @@ struct PyModule final {
input_strides.emplace_back(
at_tensor.strides().begin(), at_tensor.strides().end());

// Only works for MemoryFormat::Contiguous inputs
// Only works for MemoryFormat::Contiguous or MemoryFormat::ChannelsLast
// inputs
std::vector<torch::executor::Tensor::DimOrderType> dim_order;
for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
dim_order.push_back(cur_dim);
if (at_tensor.is_contiguous()) {
for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) {
dim_order.push_back(cur_dim);
}
} else if (
at_tensor.is_contiguous(at::MemoryFormat::ChannelsLast) &&
at_tensor.dim() == 4) {
dim_order = decltype(dim_order)({0, 2, 3, 1});
} else {
auto error_msg = "Input " + std::to_string(i) + "for method " +
method_name + " should be contiguous or channels-last.";
throw std::runtime_error(error_msg);
}
input_dim_order.push_back(std::move(dim_order));
input_tensors.emplace_back(
Expand Down
89 changes: 88 additions & 1 deletion extension/pybindings/test/make_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,40 @@ def get_inputs(self):
return (torch.ones(2, 2), torch.ones(2, 2))


class ModuleChannelsLast(torch.nn.Module):
    """Module used to exercise channels-last (NHWC) inputs in the
    pybindings tests: its forward upsamples the input 2x with
    nearest-neighbor interpolation."""

    def forward(self, x):
        # Upsample spatially by a factor of 2 (nearest neighbor).
        return torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode="nearest",
        )

    def get_methods_to_export(self):
        # Methods the test harness exports when building the program.
        return ("forward",)

    def get_inputs(self):
        # A 4-D input converted to channels-last memory format.
        return (torch.ones(1, 2, 3, 4).to(memory_format=torch.channels_last),)


class ModuleChannelsLastInDefaultOut(torch.nn.Module):
    """Module used to exercise a channels-last (NHWC) input that produces a
    default (contiguous) output: forward upsamples 2x with nearest-neighbor
    interpolation, then converts to contiguous memory format."""

    def forward(self, x):
        # Upsample spatially by a factor of 2, then force the result back to
        # the default contiguous memory format.
        return torch.nn.functional.interpolate(
            x,
            scale_factor=2,
            mode="nearest",
        ).to(memory_format=torch.contiguous_format)

    def get_methods_to_export(self):
        # Methods the test harness exports when building the program.
        return ("forward",)

    def get_inputs(self):
        # A 4-D input converted to channels-last memory format.
        return (torch.ones(1, 2, 3, 4).to(memory_format=torch.channels_last),)


class ModuleMulti(torch.nn.Module):
"""The module to serialize and execute."""

Expand Down Expand Up @@ -298,11 +332,61 @@ def test_constant_output_not_memory_planned(tester):
# The test module adds the input to torch.ones(2,2), so its output should be the same
# as adding them directly.
expected = torch.ones(2, 2) + torch.ones(2, 2)
tester.assertEqual(str(expected), str(executorch_output[0]))
tester.assertTrue(torch.allclose(expected, executorch_output[0]))

# The test module returns the state. Check that its value is correct.
tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

def test_channels_last(tester) -> None:
    """Run a program on a channels-last input and compare against eager."""
    # Create an ExecuTorch program from ModuleChannelsLast.
    model = ModuleChannelsLast()
    exported_program, inputs = create_program(model)

    # Use pybindings to load and execute the program.
    executorch_module = load_fn(exported_program.buffer)
    # Invoke the callable on executorch_module instead of calling
    # module.forward. Use only one input to test this case.
    executorch_output = executorch_module(inputs[0])[0]

    # The module upsamples its input, so the ExecuTorch output should match
    # eager-mode execution of the same model on the same input.
    expected = model(inputs[0])
    tester.assertTrue(torch.allclose(expected, executorch_output))

def test_unsupported_dim_order(tester) -> None:
    """
    Verify that the pybind layer rejects unsupported dim orders.
    """

    # Create an ExecuTorch program from ModuleChannelsLast, but replace its
    # inputs with a 5-D channels-last-3d tensor, a dim order the pybind
    # layer does not support.
    model = ModuleChannelsLast()
    exported_program, inputs = create_program(model)
    inputs = (
        torch.randn(1, 2, 3, 4, 5).to(memory_format=torch.channels_last_3d),
    )

    # Use pybindings to load and execute the program.
    executorch_module = load_fn(exported_program.buffer)

    # We expect execution to error because of the invalid input dim order.
    tester.assertRaises(RuntimeError, executorch_module, inputs[0])

def test_channels_last_in_default_out(tester) -> None:
    """Run a program with a channels-last input and contiguous output."""
    # Create an ExecuTorch program from ModuleChannelsLastInDefaultOut.
    model = ModuleChannelsLastInDefaultOut()
    exported_program, inputs = create_program(model)

    # Use pybindings to load and execute the program.
    executorch_module = load_fn(exported_program.buffer)
    # Invoke the callable on executorch_module instead of calling
    # module.forward. Use only one input to test this case.
    executorch_output = executorch_module(inputs[0])[0]

    # The module upsamples its input and returns a contiguous tensor, so the
    # ExecuTorch output should match eager-mode execution of the same model.
    expected = model(inputs[0])
    tester.assertTrue(torch.allclose(expected, executorch_output))

def test_method_meta(tester) -> None:
exported_program, inputs = create_program(ModuleAdd())

Expand Down Expand Up @@ -388,6 +472,9 @@ def test_verification_config(tester) -> None:
test_module_single_input(tester)
test_stderr_redirect(tester)
test_quantized_ops(tester)
test_channels_last(tester)
test_channels_last_in_default_out(tester)
test_unsupported_dim_order(tester)
test_constant_output_not_memory_planned(tester)
test_method_meta(tester)
test_bad_name(tester)
Expand Down
Loading