diff --git a/compiler/luci-interpreter/src/kernels/StridedSlice.cpp b/compiler/luci-interpreter/src/kernels/StridedSlice.cpp index 993256b3d1c..db45ed90668 100644 --- a/compiler/luci-interpreter/src/kernels/StridedSlice.cpp +++ b/compiler/luci-interpreter/src/kernels/StridedSlice.cpp @@ -44,7 +44,7 @@ void StridedSlice::configure() assert(begin()->element_type() == DataType::S32); assert(end()->element_type() == DataType::S32); assert(strides()->element_type() == DataType::S32); - assert(input()->shape().num_dims() <= 4); + assert(input()->shape().num_dims() <= 5); if (params().ellipsis_mask != 0) { throw std::runtime_error("ellipsis_mask is not implemented yet."); diff --git a/compiler/luci-interpreter/src/kernels/StridedSlice.test.cpp b/compiler/luci-interpreter/src/kernels/StridedSlice.test.cpp index 399cdebeda7..565530ddf1c 100644 --- a/compiler/luci-interpreter/src/kernels/StridedSlice.test.cpp +++ b/compiler/luci-interpreter/src/kernels/StridedSlice.test.cpp @@ -18,6 +18,8 @@ #include "kernels/TestUtils.h" #include "luci_interpreter/TestMemoryManager.h" +#include + namespace luci_interpreter { namespace kernels @@ -107,6 +109,48 @@ TEST(StridedSliceTest, Uint8) EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); } +TEST(StridedSliceTest, 5DCase) +{ + std::unique_ptr memory_manager = std::make_unique(); + + Shape input_shape{2, 3, 2, 2, 3}; + std::vector input_data(input_shape.num_elements()); + std::iota(std::begin(input_data), std::end(input_data), 0); + Shape begin_shape{5}; + std::vector begin_data{0, 0, 0, 0, 0}; + Shape end_shape{5}; + std::vector end_data{2, 3, 2, 2, 1}; + Shape strides_shape{5}; + std::vector strides_data{1, 1, 1, 1, 1}; + Tensor input_tensor = + makeInputTensor(input_shape, 1.0f, 0, input_data, memory_manager.get()); + Tensor begin_tensor = + makeInputTensor(begin_shape, begin_data, memory_manager.get()); + Tensor end_tensor = makeInputTensor(end_shape, end_data, memory_manager.get()); + Tensor strides_tensor = + makeInputTensor(strides_shape, strides_data, memory_manager.get()); + Tensor output_tensor = makeOutputTensor(DataType::U8, 1.0f, 0); + + StridedSliceParams params{}; + params.begin_mask = 0; + params.end_mask = 0; + params.ellipsis_mask = 0; + params.new_axis_mask = 0; + params.shrink_axis_mask = 0; + + StridedSlice kernel(&input_tensor, &begin_tensor, &end_tensor, &strides_tensor, &output_tensor, + params); + kernel.configure(); + memory_manager->allocate_memory(output_tensor); + kernel.execute(); + + std::vector output_shape{2, 3, 2, 2, 1}; + std::vector output_data{0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, + 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69}; + EXPECT_THAT(dequantizeTensorData(output_tensor), FloatArrayNear(output_data)); + EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape)); +} + } // namespace } // namespace kernels } // namespace luci_interpreter