[onert] Handle batch input shape inference on Reshape operator (#12860)
This commit fixes shape inference for the Reshape operator to handle batched input shapes.

ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh <[email protected]>
hseok-oh authored Apr 12, 2024
1 parent 3bc4433 commit 1d49752
Showing 4 changed files with 21 additions and 12 deletions.
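In short, inferReshapeShape previously received only the requested shape buffer and the input's element count, and threw as soon as the two element counts disagreed. It now receives the full input shape so it can handle batched inputs whose Reshape target shape was authored for batch 1: if the requested leading dimension is 1 and multiplying the requested element count by the input's batch size matches the input's element count, the output adopts the input's batch dimension. For example (the shapes here are illustrative, not taken from the commit), an input of shape [4, 1, 1, 1001] (4004 elements) reshaped with a constant target of [1, 1001] no longer fails; since 1001 x 4 = 4004 and the target's leading dimension is 1, the inferred output becomes [4, 1001].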
4 changes: 2 additions & 2 deletions runtime/onert/core/include/util/ShapeInference.h
@@ -87,8 +87,8 @@ ir::Shape inferPoolShape(const ir::Shape &in_shape, const ir::operation::Pool2D:
 
 template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delta_val);
 
-ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_elements,
-                            const size_t total_num_elements);
+ir::Shape inferReshapeShape(const ir::Shape &input_shape, const int32_t *shape_buf,
+                            const int32_t shape_num_elements);
 
 ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int> &axes,
                            bool keep_dims);
6 changes: 3 additions & 3 deletions runtime/onert/core/src/compiler/StaticShapeInferer.cc
@@ -978,8 +978,8 @@ void StaticShapeInferer::visit(const ir::operation::Reshape &op)
     const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
     assert(shape_buf);
 
-    ir::Shape new_shape = shape_inference::inferReshapeShape(
-      shape_buf, shape.shape().num_elements(), input.shape().num_elements());
+    ir::Shape new_shape =
+      shape_inference::inferReshapeShape(input.shape(), shape_buf, shape.shape().num_elements());
 
     // if shape is from Const, TFLC put the shape of output into tensor
     if (new_shape != output.shape())
@@ -1000,7 +1000,7 @@ void StaticShapeInferer::visit(const ir::operation::Reshape &op)
     // Let's check the new_shape option
     auto shape = op.param().new_shape;
     ir::Shape new_shape =
-      shape_inference::inferReshapeShape(shape.data(), shape.size(), input.shape().num_elements());
+      shape_inference::inferReshapeShape(input.shape(), shape.data(), shape.size());
 
     if (new_shape != output.shape())
     {
8 changes: 4 additions & 4 deletions runtime/onert/core/src/exec/DynamicShapeInferer.cc
@@ -837,8 +837,8 @@ void DynamicShapeInferer::visit(const ir::operation::Reshape &op)
     int32_t *new_shape_buf = reinterpret_cast<int32_t *>(new_shape->buffer());
     assert(new_shape_buf);
 
-    auto output_shape = shape_inference::inferReshapeShape(
-      new_shape_buf, new_shape->getShape().num_elements(), input->getShape().num_elements());
+    auto output_shape = shape_inference::inferReshapeShape(input->getShape(), new_shape_buf,
+                                                           new_shape->getShape().num_elements());
 
     // if shape is changed, change output shape and reallocate output tensor memory
     if (output_shape != output->getShape() || output->buffer() == nullptr)
@@ -853,8 +853,8 @@ void DynamicShapeInferer::visit(const ir::operation::Reshape &op)
   {
     // Let's check the new_shape option
     auto shape = op.param().new_shape;
-    auto output_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(),
-                                                           input->getShape().num_elements());
+    auto output_shape =
+      shape_inference::inferReshapeShape(input->getShape(), shape.data(), shape.size());
 
     // if shape is changed, change output shape and reallocate output tensor memory
     if (output_shape != output->getShape() || output->buffer() == nullptr)
15 changes: 12 additions & 3 deletions runtime/onert/core/src/util/ShapeInference.cc
@@ -604,11 +604,12 @@ template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delt
 template ir::Shape inferRangeShape(int start_val, int limit_val, int delta_val);
 template ir::Shape inferRangeShape(float start_val, float limit_val, float delta_val);
 
-ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_elements,
-                            const size_t total_num_elements)
+ir::Shape inferReshapeShape(const ir::Shape &input_shape, const int32_t *shape_buf,
+                            const int32_t shape_num_elements)
 {
   ir::Shape ret(shape_num_elements);
   int32_t flatten_dim = ir::Shape::kUnspecifiedDim;
+  auto total_num_elements = input_shape.num_elements();
   for (int32_t i = 0; i < shape_num_elements; ++i)
   {
     if (shape_buf[i] < 0)
@@ -628,7 +629,15 @@ ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_el
 
   // Check reshapable
   if (total_num_elements != static_cast<size_t>(ret.num_elements()))
-    throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
+  {
+    // Multi batch case
+    // TODO Handle multi batch case more precisely on runtime level
+    if ((ret.dim(0) == 1) &&
+        (total_num_elements == static_cast<size_t>(ret.num_elements() * input_shape.dim(0))))
+      ret.dim(0) = input_shape.dim(0);
+    else
+      throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
+  }
 
   return ret;
 }
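For readers who want to try the new behavior outside the runtime, the following is a minimal standalone sketch of the patched logic using plain std::vector instead of onert's ir::Shape. Only the flatten-dim handling, the multi-batch fallback, and the final error string follow the diff above; the names (inferReshape, numElements), the first error message, and the shapes in main() are illustrative and not part of the onert API.

#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <vector>

using Shape = std::vector<int32_t>;

// Product of all dimensions; 1 for an empty shape.
int64_t numElements(const Shape &s)
{
  return std::accumulate(s.begin(), s.end(), int64_t{1}, std::multiplies<int64_t>());
}

// Mirrors the patched inferReshapeShape: resolve an optional -1 ("flatten") dim
// from the input's element count, then fall back to widening a requested
// leading dim of 1 to the input's batch size when the counts still disagree.
Shape inferReshape(const Shape &input_shape, const Shape &requested)
{
  Shape ret = requested;
  int32_t flatten_dim = -1;
  const int64_t total_num_elements = numElements(input_shape);

  for (size_t i = 0; i < ret.size(); ++i)
  {
    if (ret[i] < 0)
    {
      if (flatten_dim != -1)
        throw std::runtime_error("Reshape: only one dimension may be negative"); // illustrative message
      flatten_dim = static_cast<int32_t>(i);
      ret[i] = 1;
    }
  }
  if (flatten_dim != -1)
    ret[flatten_dim] = static_cast<int32_t>(total_num_elements / numElements(ret));

  if (total_num_elements != numElements(ret))
  {
    // Multi batch case: the requested shape was written for batch 1, so adopt
    // the input's batch size if that makes the element counts match.
    if (!ret.empty() && ret[0] == 1 && total_num_elements == numElements(ret) * input_shape[0])
      ret[0] = input_shape[0];
    else
      throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
  }
  return ret;
}

int main()
{
  // Batched input against a fixed [1, 1001] target -> [4, 1001].
  assert((inferReshape({4, 1, 1, 1001}, {1, 1001}) == Shape{4, 1001}));

  // A -1 dimension is still resolved from the element count: [2, 3, 4] -> [2, 12].
  assert((inferReshape({2, 3, 4}, {2, -1}) == Shape{2, 12}));

  std::cout << "ok" << std::endl;
  return 0;
}

The sketch builds with any C++11 compiler, e.g. g++ -std=c++11 sketch.cc.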
