From b2e116244e51d95463d4158e69ba1256995c8fd3 Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Mon, 31 Mar 2025 06:40:09 -0700 Subject: [PATCH] Support non-default dim orders in pybindings + aten bridge (#9736) Summary: Now that we have better dim order support, we can pass inputs into models in non-default dim orders, such as channels last. This works in the runtime, but the python layer current asserts that input tensors are in default dim order. This PR relaxes this restriction to allow channels-last input tensors. Reviewed By: kimishpatel Differential Revision: D71716100 --- extension/aten_util/aten_bridge.cpp | 27 +++-- extension/aten_util/test/aten_bridge_test.cpp | 100 ++++++++++++++++++ extension/pybindings/pybindings.cpp | 24 +++-- extension/pybindings/test/make_test.py | 89 +++++++++++++++- 4 files changed, 221 insertions(+), 19 deletions(-) diff --git a/extension/aten_util/aten_bridge.cpp b/extension/aten_util/aten_bridge.cpp index 90fa9fbf484..351919d810b 100644 --- a/extension/aten_util/aten_bridge.cpp +++ b/extension/aten_util/aten_bridge.cpp @@ -8,6 +8,8 @@ #include +#include +#include #include #include @@ -40,7 +42,12 @@ ET_CHECK_MSG( ssize_t(a.size(i)), ssize_t(b.size(i))); } - // check strides + // check strides and dim order + std::array + expected_strides{}; + runtime::dim_order_to_stride_nocheck( + b.sizes().data(), b.dim_order().data(), b.dim(), expected_strides.data()); + for (size_t i = 0, dims = a.dim(); i < dims; ++i) { // Dont match strides if the size is 1. // Why? Because tensor is non-contig only if @@ -52,6 +59,12 @@ ET_CHECK_MSG( i, ssize_t(a.stride(i)), ssize_t(b.strides()[i])); + ET_CHECK_MSG( + (b.size(i) == 1 || (b.strides()[i] == expected_strides[i])), + "Strides don't match dim order at index:%zd, stride: %zd != expected %zd", + i, + ssize_t(a.stride(i)), + ssize_t(expected_strides[i])); } // check dtype ET_CHECK_MSG( @@ -109,13 +122,11 @@ c10::ScalarType executorch_to_torch_scalar_type( void alias_etensor_to_attensor( at::Tensor& aten_tensor, torch::executor::Tensor& mutable_et) { - // TODO(kimishpatel): contiguous according to memformat - // Right now we assume everything is channels first contiguous - // Note that input tensor must be contiguous for us to alias. - // Mixing aliasing and copying is dangerous since if we aliased - // the instance of mutatble_et to aten_tensor in the previous call, - // then in the next call copying will not be the correct behavior. - ET_CHECK_MSG(aten_tensor.is_contiguous(), "Input tensor must be contiguous"); + ET_CHECK_MSG( + aten_tensor.is_contiguous() || + aten_tensor.is_contiguous(at::MemoryFormat::ChannelsLast), + "Input tensor must have contiguous or channels last memory format"); + check_tensor_meta(aten_tensor, mutable_et); mutable_et.unsafeGetTensorImpl()->set_data(aten_tensor.mutable_data_ptr()); } diff --git a/extension/aten_util/test/aten_bridge_test.cpp b/extension/aten_util/test/aten_bridge_test.cpp index ba331162fca..d529c5ea312 100644 --- a/extension/aten_util/test/aten_bridge_test.cpp +++ b/extension/aten_util/test/aten_bridge_test.cpp @@ -154,3 +154,103 @@ TEST(ATenBridgeTest, AliasTensorPtrToATenTensor) { alias_etensor_to_attensor(at_tensor, *et_tensor_ptr); EXPECT_EQ(at_tensor.const_data_ptr(), et_tensor_ptr->const_data_ptr()); } + +TEST(ATenBridgeTest, AliasATTensorToETensorChannelsLast) { + auto at_tensor = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast); + std::vector sizes( + at_tensor.sizes().begin(), at_tensor.sizes().end()); + std::vector dim_order = {0, 2, 3, 1}; + std::vector strides( + at_tensor.strides().begin(), at_tensor.strides().end()); + auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype()); + std::vector etensor_data(at_tensor.nbytes()); + torch::executor::TensorImpl tensor_impl( + dtype, + at_tensor.dim(), + sizes.data(), + etensor_data.data(), + dim_order.data(), + strides.data()); + torch::executor::Tensor etensor(&tensor_impl); + auto aliased_at_tensor = alias_attensor_to_etensor(etensor); + EXPECT_EQ(aliased_at_tensor.const_data_ptr(), etensor_data.data()); +} + +TEST(ATenBridgeTest, AliasATTensorToETensorFailDimOrder) { + auto at_tensor = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast); + std::vector sizes( + at_tensor.sizes().begin(), at_tensor.sizes().end()); + std::vector dim_order = {0, 1, 2, 3}; + std::vector strides( + at_tensor.strides().begin(), at_tensor.strides().end()); + auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype()); + std::vector etensor_data(at_tensor.nbytes()); + torch::executor::TensorImpl tensor_impl( + dtype, + at_tensor.dim(), + sizes.data(), + etensor_data.data(), + dim_order.data(), + strides.data()); + torch::executor::Tensor etensor(&tensor_impl); + ET_EXPECT_DEATH(alias_attensor_to_etensor(etensor), ""); +} + +TEST(ATenBridgeTest, AliasETensorToATenTensorChannelsLast) { + auto at_tensor = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast); + std::vector sizes( + at_tensor.sizes().begin(), at_tensor.sizes().end()); + std::vector dim_order = {0, 2, 3, 1}; + std::vector strides( + at_tensor.strides().begin(), at_tensor.strides().end()); + auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype()); + torch::executor::TensorImpl tensor_impl( + dtype, + at_tensor.dim(), + sizes.data(), + nullptr, + dim_order.data(), + strides.data()); + torch::executor::Tensor etensor(&tensor_impl); + alias_etensor_to_attensor(at_tensor, etensor); + EXPECT_EQ(at_tensor.const_data_ptr(), etensor.const_data_ptr()); +} + +TEST(ATenBridgeTest, AliasETensorToATenTensorFailDimOrder) { + auto at_tensor = at::randn({2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast); + std::vector sizes( + at_tensor.sizes().begin(), at_tensor.sizes().end()); + std::vector dim_order = {0, 1, 2, 3}; + std::vector strides( + at_tensor.strides().begin(), at_tensor.strides().end()); + auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype()); + torch::executor::TensorImpl tensor_impl( + dtype, + at_tensor.dim(), + sizes.data(), + nullptr, + dim_order.data(), + strides.data()); + torch::executor::Tensor etensor(&tensor_impl); + ET_EXPECT_DEATH(alias_etensor_to_attensor(at_tensor, etensor), ""); +} + +TEST(ATenBridgeTest, AliasETensorToATenTensorFailUnsupportedDimOrder) { + auto at_tensor = + at::randn({1, 2, 3, 4, 5}).to(at::MemoryFormat::ChannelsLast3d); + std::vector sizes( + at_tensor.sizes().begin(), at_tensor.sizes().end()); + std::vector dim_order = {0, 2, 3, 4, 1}; + std::vector strides( + at_tensor.strides().begin(), at_tensor.strides().end()); + auto dtype = torchToExecuTorchScalarType(at_tensor.options().dtype()); + torch::executor::TensorImpl tensor_impl( + dtype, + at_tensor.dim(), + sizes.data(), + nullptr, + dim_order.data(), + strides.data()); + torch::executor::Tensor etensor(&tensor_impl); + ET_EXPECT_DEATH(alias_etensor_to_attensor(at_tensor, etensor), ""); +} diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index 0f2689b7068..a998e591f30 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -703,13 +703,6 @@ struct PyModule final { const std::string& type_str = py::str(python_input.get_type()); if (type_str == "") { auto at_tensor = python_input.cast(); - // alias_etensor_to_attensor will assert on this later, so to better - // propogate up to python we check early and throw an exception. - if (!at_tensor.is_contiguous()) { - auto error_msg = "Input " + std::to_string(i) + "for method " + - method_name + " is not contiguous."; - throw std::runtime_error(error_msg); - } #ifdef USE_ATEN_LIB EValue evalue(at_tensor); @@ -725,10 +718,21 @@ struct PyModule final { input_strides.emplace_back( at_tensor.strides().begin(), at_tensor.strides().end()); - // Only works for MemoryFormat::Contiguous inputs + // Only works for MemoryFormat::Contiguous or MemoryFormat::ChannelsLast + // inputs std::vector dim_order; - for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) { - dim_order.push_back(cur_dim); + if (at_tensor.is_contiguous()) { + for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) { + dim_order.push_back(cur_dim); + } + } else if ( + at_tensor.is_contiguous(at::MemoryFormat::ChannelsLast) && + at_tensor.dim() == 4) { + dim_order = decltype(dim_order)({0, 2, 3, 1}); + } else { + auto error_msg = "Input " + std::to_string(i) + "for method " + + method_name + " should be contiguous or channels-last."; + throw std::runtime_error(error_msg); } input_dim_order.push_back(std::move(dim_order)); input_tensors.emplace_back( diff --git a/extension/pybindings/test/make_test.py b/extension/pybindings/test/make_test.py index 6503b0dea18..32a695b99e2 100644 --- a/extension/pybindings/test/make_test.py +++ b/extension/pybindings/test/make_test.py @@ -32,6 +32,40 @@ def get_inputs(self): return (torch.ones(2, 2), torch.ones(2, 2)) +class ModuleChannelsLast(torch.nn.Module): + """The module to serialize and execute.""" + + def forward(self, x): + return torch.nn.functional.interpolate( + x, + scale_factor=2, + mode="nearest", + ) + + def get_methods_to_export(self): + return ("forward",) + + def get_inputs(self): + return (torch.ones(1, 2, 3, 4).to(memory_format=torch.channels_last),) + + +class ModuleChannelsLastInDefaultOut(torch.nn.Module): + """The module to serialize and execute.""" + + def forward(self, x): + return torch.nn.functional.interpolate( + x, + scale_factor=2, + mode="nearest", + ).to(memory_format=torch.contiguous_format) + + def get_methods_to_export(self): + return ("forward",) + + def get_inputs(self): + return (torch.ones(1, 2, 3, 4).to(memory_format=torch.channels_last),) + + class ModuleMulti(torch.nn.Module): """The module to serialize and execute.""" @@ -298,11 +332,61 @@ def test_constant_output_not_memory_planned(tester): # The test module adds the input to torch.ones(2,2), so its output should be the same # as adding them directly. expected = torch.ones(2, 2) + torch.ones(2, 2) - tester.assertEqual(str(expected), str(executorch_output[0])) + tester.assertTrue(torch.allclose(expected, executorch_output[0])) # The test module returns the state. Check that its value is correct. tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1])) + def test_channels_last(tester) -> None: + # Create an ExecuTorch program from ModuleChannelsLast. + model = ModuleChannelsLast() + exported_program, inputs = create_program(model) + + # Use pybindings to load and execute the program. + executorch_module = load_fn(exported_program.buffer) + # Inovke the callable on executorch_module instead of calling module.forward. + # Use only one input to test this case. + executorch_output = executorch_module(inputs[0])[0] + + # The test module adds the two inputs, so its output should be the same + # as adding them directly. + expected = model(inputs[0]) + tester.assertTrue(torch.allclose(expected, executorch_output)) + + def test_unsupported_dim_order(tester) -> None: + """ + Verify that the pybind layer rejects unsupported dim orders. + """ + + # Create an ExecuTorch program from ModuleChannelsLast. + model = ModuleChannelsLast() + exported_program, inputs = create_program(model) + inputs = ( + torch.randn(1, 2, 3, 4, 5).to(memory_format=torch.channels_last_3d), + ) + + # Use pybindings to load and execute the program. + executorch_module = load_fn(exported_program.buffer) + + # We expect execution to error because of the invalid input dim order. + tester.assertRaises(RuntimeError, executorch_module, inputs[0]) + + def test_channels_last_in_default_out(tester) -> None: + # Create an ExecuTorch program from ModuleChannelsLastInDefaultOut. + model = ModuleChannelsLastInDefaultOut() + exported_program, inputs = create_program(model) + + # Use pybindings to load and execute the program. + executorch_module = load_fn(exported_program.buffer) + # Inovke the callable on executorch_module instead of calling module.forward. + # Use only one input to test this case. + executorch_output = executorch_module(inputs[0])[0] + + # The test module adds the two inputs, so its output should be the same + # as adding them directly. + expected = model(inputs[0]) + tester.assertTrue(torch.allclose(expected, executorch_output)) + def test_method_meta(tester) -> None: exported_program, inputs = create_program(ModuleAdd()) @@ -388,6 +472,9 @@ def test_verification_config(tester) -> None: test_module_single_input(tester) test_stderr_redirect(tester) test_quantized_ops(tester) + test_channels_last(tester) + test_channels_last_in_default_out(tester) + test_unsupported_dim_order(tester) test_constant_output_not_memory_planned(tester) test_method_meta(tester) test_bad_name(tester)