diff --git a/source/neuropod/backends/torchscript/torch_backend.cc b/source/neuropod/backends/torchscript/torch_backend.cc index e84029d2..0fa5e319 100644 --- a/source/neuropod/backends/torchscript/torch_backend.cc +++ b/source/neuropod/backends/torchscript/torch_backend.cc @@ -130,7 +130,8 @@ void insert_value_in_output(NeuropodValueMap & output, // Torch tensor // Transfer it to CPU // .to(device) is a no-op if the tensor is already transferred - auto tensor = value.toTensor().to(torch::kCPU); + // .contiguous() is a no-op if the tensor is already contiguous + auto tensor = value.toTensor().to(torch::kCPU).contiguous(); // Get the type and make a TorchNeuropodTensor auto neuropod_tensor_type = get_neuropod_type_from_torch_type(tensor.scalar_type()); diff --git a/source/python/neuropod/backends/torchscript/test/test_torchscript_packaging.py b/source/python/neuropod/backends/torchscript/test/test_torchscript_packaging.py index a3ab5c78..eac1816e 100644 --- a/source/python/neuropod/backends/torchscript/test/test_torchscript_packaging.py +++ b/source/python/neuropod/backends/torchscript/test/test_torchscript_packaging.py @@ -136,6 +136,18 @@ def forward(self, x, y): return SomeNamedTuple(sum=x + y, difference=x - y, product=x * y) +class SplitterModel(torch.jit.ScriptModule): + """ + This model returns a non-contiguous output + """ + + @torch.jit.script_method + def forward(self, x): + x1 = x[:, :2] + x2 = x[:, 2:] + return {"x1": x1, "x2": x2} + + @requires_frameworks("torchscript") class TestTorchScriptPackaging(unittest.TestCase): def package_simple_addition_model(self, do_fail=False): @@ -280,6 +292,26 @@ def test_named_tuple_model_failure(self): with self.assertRaises(ValueError): self.package_named_tuple_model(do_fail=True) + def test_noncontiguous_array(self): + # Test a non-contiguous output + x = np.arange(16).astype(np.int64).reshape(4, 4) + + with TemporaryDirectory() as test_dir: + neuropod_path = os.path.join(test_dir, "test_neuropod") + + create_torchscript_neuropod( + neuropod_path=neuropod_path, + model_name="splitter", + module=SplitterModel(), + input_spec=[{"name": "x", "dtype": "int64", "shape": (4, 4)}], + output_spec=[ + {"name": "x1", "dtype": "int64", "shape": (4, 2)}, + {"name": "x2", "dtype": "int64", "shape": (4, 2)}, + ], + test_input_data={"x": x}, + test_expected_out={"x1": x[:, :2], "x2": x[:, 2:]}, + ) + if __name__ == "__main__": unittest.main()