From 6fc7ecedc034d706c6580b24d032f6d8c9aecf45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=96=91=EC=A2=85=EC=9B=90?= Date: Thu, 19 Sep 2024 15:53:47 +0900 Subject: [PATCH] [luci/service] Support CircleOutputDummy as reshape's shape This commit supports CircleOutputDummy as reshape's shape. - This is the 4th step in supporting dynamic shape inference for the reshape operation. ONE-DCO-1.0-Signed-off-by: Jongwon Yang --- .../luci/service/src/Nodes/CircleReshape.cpp | 28 +++++++- .../service/src/Nodes/CircleReshape.test.cpp | 64 +++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/compiler/luci/service/src/Nodes/CircleReshape.cpp b/compiler/luci/service/src/Nodes/CircleReshape.cpp index 10a10b5396a..0cad3c8db8e 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.cpp @@ -69,8 +69,11 @@ namespace sinf * @note CircleReshape always has two inputs: `tensor` and `shape`. * The `shape` can be CircleConst, CircleOutputDummy, or CircleNode. * - If the `shape` is CircleConst, the shape is inferred from the constant. + * - If the `shape` is CircleOutputDummy, the shape is inferred from: + * - the attribute if it exists. + * - the node itself if the attribute does not exist. * - Else, the shape is inferred from the node iteself. - * - TODO support CircleOutputDummy and CircleNode + * - TODO support CircleNode */ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) { @@ -99,6 +102,29 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node) } } } + else if (auto dummy_shape_node = dynamic_cast(node->shape())) + { + if (node->newShape()->rank() > 0) + { + shape_by_input.rank(node->newShape()->rank()); + + for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis) + { + shape_by_input.dim(axis) = node->newShape()->dim(axis); + if (node->newShape()->dim(axis) < 0) + { + shape_by_input.dim(axis).unset(); + } + } + } + else + { + // If the `shape` is CircleOutputDummy and the attribute does not exist, + // this status cannot be handled by this shape inference rule. + // TODO support no `shape` and no attribute case + shape_by_input = circle_shape(node); + } + } else { // We use shape from the node itself diff --git a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp index 0d22002d0e5..8f6175cfb9e 100644 --- a/compiler/luci/service/src/Nodes/CircleReshape.test.cpp +++ b/compiler/luci/service/src/Nodes/CircleReshape.test.cpp @@ -134,3 +134,67 @@ TEST(ShapeRuleTest, reshape_should_infer) ASSERT_TRUE(output_shape.dim(1).known()); ASSERT_EQ(4, output_shape.dim(1).value()); } + +TEST(ShapeRuleTest, reshape_by_dummy_static) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_dummy = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_dummy->dtype(loco::DataType::S32); + shape_dummy->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_dummy); + node_reshape->newShape()->rank(2); + node_reshape->newShape()->dim(0) = 6; + node_reshape->newShape()->dim(1) = 4; + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(6, output_shape.dim(0).value()); + ASSERT_EQ(4, output_shape.dim(1).value()); +} + +TEST(ShapeRuleTest, reshape_by_dummy_dynamic) +{ + auto g = loco::make_graph(); + auto node_reshape = g->nodes()->create(); + auto tensor_input = g->nodes()->create(); + auto shape_dummy = g->nodes()->create(); + + tensor_input->dtype(loco::DataType::S32); + tensor_input->shape({2, 3, 4}); + tensor_input->shape_status(luci::ShapeStatus::VALID); + + shape_dummy->dtype(loco::DataType::S32); + shape_dummy->shape_status(luci::ShapeStatus::VALID); + + node_reshape->tensor(tensor_input); + node_reshape->shape(shape_dummy); + node_reshape->newShape()->rank(2); + node_reshape->newShape()->dim(0) = -1; + node_reshape->newShape()->dim(1) = 4; + + loco::TensorShape output_shape; + luci::sinf::Rule shape_inf_rule; + + ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape)); + + ASSERT_EQ(2, output_shape.rank()); + ASSERT_TRUE(output_shape.dim(0).known()); + ASSERT_TRUE(output_shape.dim(1).known()); + ASSERT_EQ(6, output_shape.dim(0).value()); + ASSERT_EQ(4, output_shape.dim(1).value()); +}