diff --git a/compiler/luci/pass/include/luci/CircleOptimizer.h b/compiler/luci/pass/include/luci/CircleOptimizer.h index ed7cbf611df..14323639f81 100644 --- a/compiler/luci/pass/include/luci/CircleOptimizer.h +++ b/compiler/luci/pass/include/luci/CircleOptimizer.h @@ -77,6 +77,7 @@ class CircleOptimizer final FuseActivationFunction, FusePRelu, FuseGelu, + FuseGRU, FuseRsqrt, FuseRmsNorm, FuseRoPE, diff --git a/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h b/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h new file mode 100644 index 00000000000..152dc427d95 --- /dev/null +++ b/compiler/luci/pass/include/luci/Pass/FuseGRUPass.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __LUCI_FUSE_GRU_PASS_H__ +#define __LUCI_FUSE_GRU_PASS_H__ + +#include + +namespace luci +{ + +/** + * @brief Class to fuse certain pattern of subgraph into CircleGRU + * + * For detailed subgraph pattern to be fused, please check its implementation. + */ +struct FuseGRUPass final : public logo::Pass +{ + const char *name(void) const final { return "luci::FuseGRUPass"; } + + bool run(loco::Graph *g) final; +}; + +} // namespace luci + +#endif // __LUCI_FUSE_GRU_PASS_H__ diff --git a/compiler/luci/pass/src/CircleOptimizer.cpp b/compiler/luci/pass/src/CircleOptimizer.cpp index ef6a2d86a4d..ea38e460393 100644 --- a/compiler/luci/pass/src/CircleOptimizer.cpp +++ b/compiler/luci/pass/src/CircleOptimizer.cpp @@ -52,6 +52,7 @@ #include "luci/Pass/FusePreActivationBatchNormPass.h" #include "luci/Pass/FusePReluPass.h" #include "luci/Pass/FuseGeluPass.h" +#include "luci/Pass/FuseGRUPass.h" #include "luci/Pass/FuseRsqrtPass.h" #include "luci/Pass/FuseSliceWithTConvPass.h" #include "luci/Pass/FuseHorizontalFullyConnectedPass.h" @@ -398,7 +399,7 @@ void CircleOptimizer::optimize(loco::Graph *g) const option_to_pass[Options::Algorithm::XpSepActFromTransposeConv] = &createPassInstance; option_to_pass[Options::Algorithm::ForwardReshapeToUnaryOp] = &createPassInstance; option_to_pass[Options::Algorithm::ForwardTransposeOp] = &createPassInstance; - // clang-format on + // clang-format on for (auto const &m : option_to_pass) { diff --git a/compiler/luci/pass/src/FuseGRUPass.cpp b/compiler/luci/pass/src/FuseGRUPass.cpp new file mode 100644 index 00000000000..7164c2a3d1d --- /dev/null +++ b/compiler/luci/pass/src/FuseGRUPass.cpp @@ -0,0 +1,482 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseGRUPass.h" +#include "helpers/NodeFiller.h" + +#include + +#include +#include + +#include + +#include + +// Helper to fuse GRU +namespace +{ + +class GRUPattern final +{ +public: + GRUPattern(luci::CircleWhileOut *candidate) + { + assert(candidate); + _while_out_node = candidate; + } + ~GRUPattern() = default; + + bool matched(); + +public: + luci::CircleNode *_ifm = nullptr; + luci::CircleConst *_weight_ih = nullptr; + luci::CircleConst *_bias_ih = nullptr; + luci::CircleConst *_weight_hh = nullptr; + luci::CircleConst *_bias_hh = nullptr; + + luci::CircleConst *_hidden_input = nullptr; + + luci::CircleConst *_less_const = nullptr; + + luci::CircleWhile *_while_node = nullptr; + luci::CircleWhileOut *_while_out_node = nullptr; + + luci::CircleReshape *reshape = nullptr; + luci::CircleConst *reshape_shape = nullptr; + + luci::CircleAdd *add_6 = nullptr; + luci::CircleMul *mul_1 = nullptr; + luci::CircleMul *mul_3 = nullptr; + luci::CircleSub *sub_with_const = nullptr; + luci::CircleTanh *tanh = nullptr; + luci::CircleLogistic *logistic_2 = nullptr; + luci::CircleAdd *add_5 = nullptr; + luci::CircleMul *mul_2 = nullptr; + luci::CircleAdd *add_1 = nullptr; + luci::CircleSplitOut *split_1_out = nullptr; + luci::CircleSplitOut *split_2_out = nullptr; + luci::CircleSplit *split_1 = nullptr; + luci::CircleSplit *split_2 = nullptr; + luci::CircleLogistic *logistic_1 = nullptr; + luci::CircleAdd *add_4 = nullptr; + luci::CircleFullyConnected *fc_1 = nullptr; + luci::CircleFullyConnected *fc_2 = nullptr; +}; + +/** + * Below diagram shows GRU pattern to fuse. + * Note: this pattern for GRU with `return_sequences=False` + * - the below pattern will be replaced with one GRU + * Main Graph: + * [In] [CircleConst] [CircleConst] [CircleConst] [CircleConst] + * | | | | | + * V | | | | + * [CircleWhile]<----------------------------------------------------- + * | + * V + * [CircleWhileOut] + * | + * V + * [Out] + * + * Condition Graph: + * [In] [CircleConst] (scalar int32 value) + * | | + * V | + * [Less]------ + * | + * V + * [Out] + * + * Body Graph must contain: + * - 2 CircleFullyConnected nodes; + * - 3 CircleMul nodes; + * - 2 CircleLogistic nodes; + * - 2 CircleSplit nodes; + * - 6 CircleAdd nodes; + * - 1 CircleGather node; + * - 1 CircleReshape node; + * - 1 CircleSub node; + * - 1 CircleTanh node; + * - 6 CircleSplitOut nodes; + * - 5 CircleInput nodes; + * - 5 CircleOutput nodes; + * + * Body Graph: + * [In_1] [In_2]--->[Add_2 (with Const)]--->[Out_2] [In_3] + * | \ | | + * | \ [In_4]---[Gather] [Add_3 (with Const)] + * | [FullyConnected_1] | | | + * | | [Out_4] | [Out_3] + * | [Split_1] [FullyConnected_2] + * | / | \ | + * | | | \ [Split_2] + * | [Add_1] ----------------------------------------------/ | | + * | | | | | | + * | | | ------------------------------------[Add_4] | + * | | | | | + * | | | [Logistic_1] | + * | | | | | + * | | ----------------------------------------[Mul_2] | + * | | \ / + * | | [Add_5] + * | | | + * | [Logistic_2] [Tanh] + * \ / \ | + * [Mul_1] [Sub (with const)] | + * \ \ | + * \ ---------------------------[Mul_3] + * \ / + * \ / + * --------------------[Add_6]------------------------------ + * / \ + * / \ + * [Reshape] [Out_5] + * | + * [Out_1] + */ + +#define CHECK_OR_FALSE(condition) \ + if (not(condition)) \ + return false; + +bool GRUPattern::matched() +{ + // 0 - check while node + _while_node = loco::must_cast(_while_out_node->input()); + CHECK_OR_FALSE(_while_node != nullptr); + + // 1 - check condition graph + { + const auto cond_graph = _while_node->cond_graph(); + + const auto cond_nodes = loco::active_nodes(loco::output_nodes(cond_graph)); + CHECK_OR_FALSE(cond_nodes.size() == 4); + + luci::CircleLess *less_node = nullptr; + for (auto node : cond_nodes) + { + less_node = dynamic_cast(node); + if (less_node != nullptr) + break; + } + CHECK_OR_FALSE(less_node != nullptr); + + luci::CircleNode *less_input = nullptr; + CHECK_OR_FALSE(luci::fill(&less_input, &_less_const).with_commutative_args_of(less_node)); + CHECK_OR_FALSE(_less_const->dtype() == loco::DataType::S32); + CHECK_OR_FALSE(_less_const->size() == 1); + CHECK_OR_FALSE(_less_const->at(0) > 0); + } + + // 2 - Check while's input nodes + // Save hidden state input node + { + CHECK_OR_FALSE(_while_node->input_count() == 5); + + // Save input node + _ifm = loco::must_cast(_while_node->input(4)); + _hidden_input = loco::must_cast(_while_node->input(3)); + } + + // 3 - check body graph + { + const auto body_graph = _while_node->body_graph(); + + CHECK_OR_FALSE(loco::input_nodes(body_graph).size() == 5); + CHECK_OR_FALSE(loco::output_nodes(body_graph).size() == 5); + + /* Let's check the bottom part of the body graph + * --------------------[Add_6]------------------------------ + * / \ + * / \ + * [Reshape] [Out_5] + * | + * [Out_1] + */ + + const auto body_nodes = loco::active_nodes(loco::output_nodes(body_graph)); + + for (auto node : loco::active_nodes(loco::output_nodes(body_graph))) + { + reshape = dynamic_cast(node); + if (reshape) + break; + } + CHECK_OR_FALSE(reshape != nullptr); + + add_6 = loco::must_cast(reshape->tensor()); + + /* Let's check the next bottom part above add_6 + * | [Logistic_2] [Tanh] + * \ / \ | + * [Mul_1] [Sub (with const)] | + * \ \ | + * \ ---------------------------[Mul_3] + * \ / + * \ / + * --------------------[Add_6]------------------------------ + */ + + CHECK_OR_FALSE(luci::fill(&mul_1, &mul_3).with_args_of(add_6)); + CHECK_OR_FALSE(luci::fill(&sub_with_const, &tanh).with_args_of(mul_3)); + + logistic_2 = loco::must_cast(sub_with_const->y()); + + /* Let's check the next bottom part above logistic_2 + * | | | \ [Split_2] + * | [Add_1] ----------------------------------------------/ | | + * | | | | | | + * | | | ------------------------------------[Add_4] | + * | | | | | + * | | | [Logistic_1] | + * | | | | | + * | | ----------------------------------------[Mul_2] | + * | | \ / + * | | [Add_5] + * | | | + * | [Logistic_2] [Tanh] + * \ / \ | + */ + add_5 = loco::must_cast(tanh->x()); + add_1 = loco::must_cast(logistic_2->x()); + CHECK_OR_FALSE(luci::fill(&split_1_out, &split_2_out).with_commutative_args_of(add_1)); + CHECK_OR_FALSE(luci::fill(&split_2_out, &mul_2).with_commutative_args_of(add_5)); + split_2 = loco::must_cast(split_2_out->input()); + CHECK_OR_FALSE(luci::fill(&split_1_out, &logistic_1).with_commutative_args_of(mul_2)); + split_1 = loco::must_cast(split_1_out->input()); + add_4 = loco::must_cast(logistic_1->x()); + CHECK_OR_FALSE(luci::fill(&split_1_out, &split_2_out).with_args_of(add_4)); + + /* Let's check the remainig top part + * [In_1] [In_2]--->[Add_2 (with Const)]--->[Out_2] [In_3] + * | \ | | + * | \ [In_4]---[Gather] [Add_3 (with Const)] + * | [FullyConnected_1] | | | + * | | [Out_4] | [Out_3] + * | [Split_1] [FullyConnected_2] + * | / | \ | + * | | | \ [Split_2] + * | [Add_1] ----------------------------------------------/ | | + */ + fc_1 = loco::must_cast(split_1->input()); + fc_2 = loco::must_cast(split_2->input()); + + { + _weight_ih = loco::must_cast(fc_1->weights()); + _bias_ih = dynamic_cast(fc_1->bias()); + _weight_hh = loco::must_cast(fc_2->weights()); + _bias_hh = dynamic_cast(fc_2->bias()); + if (_weight_ih == nullptr or _weight_hh == nullptr) + return false; + } + } + + return true; +} + +class FuseGRU final +{ +public: + FuseGRU(const GRUPattern *p) : _p(p) {} + +public: + void apply(void); + +private: + luci::CircleGRU *create_circle_gru(loco::Graph *graph); + +private: + const GRUPattern *_p; +}; + +template +void copy_values(const luci::CircleConst *node, luci::CircleConst *cloned) +{ + assert(T == node->dtype()); + assert(T == cloned->dtype()); + + const auto size = node->size(); + cloned->size(size); + for (uint32_t i = 0; i < size; i++) + cloned->at(i) = node->at(i); +} + +luci::CircleConst *clone_circleconst(luci::CircleConst *node, loco::Graph *graph) +{ + auto cloned = graph->nodes()->create(); + + if (cloned != nullptr) + { + // dtype/shape + cloned->dtype(node->dtype()); + cloned->rank(node->rank()); + + // values + switch (node->dtype()) + { + case loco::DataType::FLOAT32: + copy_values(node, cloned); + break; + + case loco::DataType::U8: + copy_values(node, cloned); + break; + + case loco::DataType::S8: + copy_values(node, cloned); + break; + + case loco::DataType::S16: + copy_values(node, cloned); + break; + + case loco::DataType::S32: + copy_values(node, cloned); + break; + + case loco::DataType::S64: + copy_values(node, cloned); + break; + + case loco::DataType::BOOL: + copy_values(node, cloned); + break; + + default: + throw std::runtime_error("FuseGRU: Unsupported data type"); + } + } + + return cloned; +} + +luci::CircleGRU *FuseGRU::create_circle_gru(loco::Graph *graph) +{ + assert(graph); + + auto weight_ih_cloned = clone_circleconst(_p->_weight_ih, graph); + luci::copy_common_attributes(_p->_weight_ih, weight_ih_cloned); + + auto weight_hh_cloned = clone_circleconst(_p->_weight_hh, graph); + luci::copy_common_attributes(_p->_weight_hh, weight_hh_cloned); + + luci::CircleNode *bias_ih_cloned = nullptr; + if (_p->_bias_ih != nullptr) + { + bias_ih_cloned = clone_circleconst(_p->_bias_ih, graph); + luci::copy_common_attributes(_p->_bias_ih, bias_ih_cloned); + } + else + { + bias_ih_cloned = graph->nodes()->create(); + } + + luci::CircleNode *bias_hh_cloned = nullptr; + if (_p->_bias_hh != nullptr) + { + bias_hh_cloned = clone_circleconst(_p->_bias_hh, graph); + luci::copy_common_attributes(_p->_bias_hh, bias_hh_cloned); + } + else + { + bias_hh_cloned = graph->nodes()->create(); + } + + auto hidden_input_cloned = clone_circleconst(_p->_hidden_input, graph); + luci::copy_common_attributes(_p->_hidden_input, hidden_input_cloned); + + auto less_const_cloned = clone_circleconst(_p->_less_const, graph); + luci::copy_common_attributes(_p->_less_const, less_const_cloned); + + // Create and configure new CircleGRU operation. + auto circle_gru = graph->nodes()->create(); + circle_gru->input(_p->_ifm); + circle_gru->hidden_hidden(weight_hh_cloned); + circle_gru->hidden_input(weight_ih_cloned); + circle_gru->hidden_hidden_bias(bias_hh_cloned); + circle_gru->hidden_input_bias(bias_ih_cloned); + circle_gru->state(hidden_input_cloned); + + // Note: Now support only returnSequences = false + circle_gru->returnSequences(false); + circle_gru->name(_p->_while_node->name() + "_FusedCircleGRU"); + + return circle_gru; +} + +void FuseGRU::apply() +{ + auto graph = _p->_while_out_node->graph(); + + auto gru_out = create_circle_gru(graph); + + // set origin + std::vector> origin_vec{ + luci::get_origin(_p->_while_node), luci::get_origin(_p->_while_out_node), + luci::get_origin(_p->_weight_hh), luci::get_origin(_p->_weight_ih)}; + + luci::add_origin(gru_out, luci::composite_origin(origin_vec)); + + replace(_p->_while_out_node).with(gru_out); +} + +} // namespace + +namespace +{ + +bool fuse_gru(luci::CircleWhileOut *while_out_node) +{ + assert(while_out_node); + + // check first pattern + GRUPattern pattern(while_out_node); + if (pattern.matched()) + { + FuseGRU fuse(&pattern); + fuse.apply(); + return true; + } + + return false; +} + +} // namespace + +namespace luci +{ + +bool FuseGRUPass::run(loco::Graph *g) +{ + bool changed = false; + + for (auto node : loco::active_nodes(loco::output_nodes(g))) + { + auto while_out_node = dynamic_cast(node); + if (not while_out_node) + continue; + + if (fuse_gru(while_out_node)) + changed = true; + } + + return changed; +} + +} // namespace luci diff --git a/compiler/luci/pass/src/FuseGRUPass.test.cpp b/compiler/luci/pass/src/FuseGRUPass.test.cpp new file mode 100644 index 00000000000..bb9df366606 --- /dev/null +++ b/compiler/luci/pass/src/FuseGRUPass.test.cpp @@ -0,0 +1,418 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "luci/Pass/FuseGRUPass.h" + +#include + +#include + +#include + +namespace +{ + +using namespace luci::test; + +class GRUGraphlet +{ +public: + GRUGraphlet() = default; + + void init(loco::Graph *g) + { + _while_node = g->nodes()->create(5, 5); + _while_out_node = g->nodes()->create(); + _hidden_node = g->nodes()->create(); + _hidden_node->dtype(loco::DataType::FLOAT32); + _time_node = g->nodes()->create(); + _time_node->dtype(loco::DataType::FLOAT32); + _state_node = g->nodes()->create(); + _state_node->dtype(loco::DataType::FLOAT32); + + _body_graph = loco::make_graph(); + _cond_graph = loco::make_graph(); + + _less_node = _cond_graph->nodes()->create(); + _less_const_node = _cond_graph->nodes()->create(); + _less_const_node->dtype(loco::DataType::S32); + _less_const_node->size(1); + _less_const_node->at(0) = 1; + + _add_node_1 = _body_graph->nodes()->create(); + _add_node_2 = _body_graph->nodes()->create(); + _add_node_3 = _body_graph->nodes()->create(); + _add_node_4 = _body_graph->nodes()->create(); + _add_node_5 = _body_graph->nodes()->create(); + _add_node_6 = _body_graph->nodes()->create(); + + _fc_node_1 = _body_graph->nodes()->create(); + _fc_node_2 = _body_graph->nodes()->create(); + _fc_weight_1 = _body_graph->nodes()->create(); + _fc_weight_1->dtype(loco::DataType::FLOAT32); + _fc_weight_2 = _body_graph->nodes()->create(); + _fc_weight_2->dtype(loco::DataType::FLOAT32); + _fc_bias_1 = _body_graph->nodes()->create(); + _fc_bias_1->dtype(loco::DataType::FLOAT32); + _fc_bias_2 = _body_graph->nodes()->create(); + _fc_bias_2->dtype(loco::DataType::FLOAT32); + + _split_const = _body_graph->nodes()->create(); + _split_const->dtype(loco::DataType::S32); + + _logistic_node_1 = _body_graph->nodes()->create(); + _logistic_node_2 = _body_graph->nodes()->create(); + + _gather_node = _body_graph->nodes()->create(); + + _mul_node_1 = _body_graph->nodes()->create(); + _mul_node_2 = _body_graph->nodes()->create(); + _mul_node_3 = _body_graph->nodes()->create(); + + _tanh_node = _body_graph->nodes()->create(); + _sub_node = _body_graph->nodes()->create(); + + _split_node_1 = _body_graph->nodes()->create(); + _split_node_2 = _body_graph->nodes()->create(); + _split_out_node_1 = _body_graph->nodes()->create(); + _split_out_node_2 = _body_graph->nodes()->create(); + _split_out_node_3 = _body_graph->nodes()->create(); + _split_out_node_4 = _body_graph->nodes()->create(); + _split_out_node_5 = _body_graph->nodes()->create(); + _split_out_node_6 = _body_graph->nodes()->create(); + + _reshape_node = _body_graph->nodes()->create(); + + auto graph_input_cond_graph = _cond_graph->inputs()->create(); + _cond_input_node = _cond_graph->nodes()->create(); + _cond_input_node->index(graph_input_cond_graph->index()); + + auto graph_output_cond_graph = _cond_graph->outputs()->create(); + _cond_output_node = _cond_graph->nodes()->create(); + _cond_output_node->index(graph_output_cond_graph->index()); + + auto graph_input_body_graph_1 = _body_graph->inputs()->create(); + _body_input_node_1 = _body_graph->nodes()->create(); + _body_input_node_1->index(graph_input_body_graph_1->index()); + + auto graph_input_body_graph_2 = _body_graph->inputs()->create(); + _body_input_node_2 = _body_graph->nodes()->create(); + _body_input_node_2->index(graph_input_body_graph_2->index()); + + auto graph_input_body_graph_3 = _body_graph->inputs()->create(); + _body_input_node_3 = _body_graph->nodes()->create(); + _body_input_node_3->index(graph_input_body_graph_3->index()); + + auto graph_input_body_graph_4 = _body_graph->inputs()->create(); + _body_input_node_4 = _body_graph->nodes()->create(); + _body_input_node_4->index(graph_input_body_graph_4->index()); + + auto graph_input_body_graph_5 = _body_graph->inputs()->create(); + _body_input_node_5 = _body_graph->nodes()->create(); + _body_input_node_5->index(graph_input_body_graph_5->index()); + + auto graph_output_body_graph_1 = _body_graph->outputs()->create(); + _body_output_node_1 = _body_graph->nodes()->create(); + _body_output_node_1->index(graph_output_body_graph_1->index()); + + auto graph_output_body_graph_2 = _body_graph->outputs()->create(); + _body_output_node_2 = _body_graph->nodes()->create(); + _body_output_node_2->index(graph_output_body_graph_2->index()); + + auto graph_output_body_graph_3 = _body_graph->outputs()->create(); + _body_output_node_3 = _body_graph->nodes()->create(); + _body_output_node_3->index(graph_output_body_graph_3->index()); + + auto graph_output_body_graph_4 = _body_graph->outputs()->create(); + _body_output_node_4 = _body_graph->nodes()->create(); + _body_output_node_4->index(graph_output_body_graph_4->index()); + + auto graph_output_body_graph_5 = _body_graph->outputs()->create(); + _body_output_node_5 = _body_graph->nodes()->create(); + _body_output_node_5->index(graph_output_body_graph_5->index()); + } + + void invalid_less_const_type() { _less_const_node->dtype(loco::DataType::S16); } + +protected: + luci::CircleWhile *_while_node = nullptr; + luci::CircleWhileOut *_while_out_node = nullptr; + luci::CircleConst *_time_node = nullptr; + luci::CircleConst *_state_node = nullptr; + luci::CircleConst *_hidden_node = nullptr; + + luci::CircleInput *_cond_input_node = nullptr; + luci::CircleLess *_less_node = nullptr; + luci::CircleConst *_less_const_node = nullptr; + luci::CircleOutput *_cond_output_node = nullptr; + + luci::CircleInput *_body_input_node_1 = nullptr; + luci::CircleInput *_body_input_node_2 = nullptr; + luci::CircleInput *_body_input_node_3 = nullptr; + luci::CircleInput *_body_input_node_4 = nullptr; + luci::CircleInput *_body_input_node_5 = nullptr; + + luci::CircleOutput *_body_output_node_1 = nullptr; + luci::CircleOutput *_body_output_node_2 = nullptr; + luci::CircleOutput *_body_output_node_3 = nullptr; + luci::CircleOutput *_body_output_node_4 = nullptr; + luci::CircleOutput *_body_output_node_5 = nullptr; + + luci::CircleAdd *_add_node_1 = nullptr; + luci::CircleAdd *_add_node_2 = nullptr; + luci::CircleAdd *_add_node_3 = nullptr; + luci::CircleAdd *_add_node_4 = nullptr; + luci::CircleAdd *_add_node_5 = nullptr; + luci::CircleAdd *_add_node_6 = nullptr; + + luci::CircleMul *_mul_node_1 = nullptr; + luci::CircleMul *_mul_node_2 = nullptr; + luci::CircleMul *_mul_node_3 = nullptr; + + luci::CircleSub *_sub_node = nullptr; + luci::CircleTanh *_tanh_node = nullptr; + luci::CircleReshape *_reshape_node = nullptr; + luci::CircleGather *_gather_node = nullptr; + luci::CircleLogistic *_logistic_node_1 = nullptr; + luci::CircleLogistic *_logistic_node_2 = nullptr; + luci::CircleSplit *_split_node_1 = nullptr; + luci::CircleSplit *_split_node_2 = nullptr; + + luci::CircleSplitOut *_split_out_node_1 = nullptr; + luci::CircleSplitOut *_split_out_node_2 = nullptr; + luci::CircleSplitOut *_split_out_node_3 = nullptr; + luci::CircleSplitOut *_split_out_node_4 = nullptr; + luci::CircleSplitOut *_split_out_node_5 = nullptr; + luci::CircleSplitOut *_split_out_node_6 = nullptr; + + luci::CircleFullyConnected *_fc_node_1 = nullptr; + luci::CircleFullyConnected *_fc_node_2 = nullptr; + + luci::CircleConst *_split_const = nullptr; + luci::CircleConst *_fc_weight_1 = nullptr; + luci::CircleConst *_fc_bias_1 = nullptr; + luci::CircleConst *_fc_weight_2 = nullptr; + luci::CircleConst *_fc_bias_2 = nullptr; + + std::unique_ptr _cond_graph; + std::unique_ptr _body_graph; +}; + +class FuseGRUTestGraph1 : public TestIOGraph, public GRUGraphlet +{ +public: + FuseGRUTestGraph1() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GRUGraphlet::init(g()); + + _while_node->input(0, _time_node); + _while_node->input(1, _time_node); + _while_node->input(2, _state_node); + _while_node->input(3, _hidden_node); + _while_node->input(4, input()); + + _while_out_node->input(_while_node); + output()->from(_while_out_node); + + _while_node->cond_graph(_cond_graph.get()); + _while_node->body_graph(_body_graph.get()); + + // cond graph + _less_node->x(_cond_input_node); + _less_node->y(_less_const_node); + _cond_output_node->from(_less_node); + + // body graph + _add_node_1->x(_body_input_node_1); + _add_node_1->y(_split_const); + _add_node_2->x(_body_input_node_2); + _add_node_2->y(_split_const); + + _body_output_node_5->from(_add_node_1); + _body_output_node_4->from(_add_node_2); + + _gather_node->params(_body_input_node_2); + _gather_node->indices(_body_input_node_1); + _fc_node_1->input(_body_input_node_4); + _fc_node_1->weights(_fc_weight_1); + _fc_node_1->bias(_fc_bias_1); + _fc_node_2->input(_gather_node); + _fc_node_2->weights(_fc_weight_2); + _fc_node_2->bias(_fc_bias_2); + + _split_node_1->input(_fc_node_1); + _split_node_1->split_dim(_split_const); + _split_node_2->input(_fc_node_2); + _split_node_2->split_dim(_split_const); + + _split_out_node_1->input(_split_node_1); + _split_out_node_2->input(_split_node_1); + _split_out_node_3->input(_split_node_1); + + _split_out_node_4->input(_split_node_2); + _split_out_node_5->input(_split_node_2); + _split_out_node_6->input(_split_node_2); + + _add_node_3->x(_split_out_node_1); + _add_node_3->y(_split_out_node_4); + + _add_node_4->x(_split_out_node_3); + _add_node_4->y(_split_out_node_6); + + _logistic_node_1->x(_add_node_3); + + _mul_node_1->x(_body_input_node_4); + _mul_node_1->y(_logistic_node_1); + + _sub_node->y(_logistic_node_1); + _sub_node->x(_split_const); + + _logistic_node_2->x(_add_node_4); + + _mul_node_2->x(_split_out_node_2); + _mul_node_2->y(_logistic_node_2); + + _add_node_5->x(_split_out_node_5); + _add_node_5->y(_mul_node_2); + + _tanh_node->x(_add_node_5); + + _mul_node_3->x(_sub_node); + _mul_node_3->y(_tanh_node); + + _add_node_6->x(_mul_node_1); + _add_node_6->y(_mul_node_3); + + _reshape_node->shape(_add_node_6); + + _body_output_node_3->from(_reshape_node); + } +}; + +class FuseGRUTestNegGraph : public TestIOGraph, public GRUGraphlet +{ +public: + FuseGRUTestNegGraph() = default; + + void init(void) + { + TestIOGraph::init({1}, {1}); + GRUGraphlet::init(g()); + + invalid_less_const_type(); + + _while_node->input(0, _time_node); + _while_node->input(1, _time_node); + _while_node->input(2, _state_node); + _while_node->input(3, _hidden_node); + _while_node->input(4, input()); + + _while_node->cond_graph(_cond_graph.get()); + _while_node->body_graph(_body_graph.get()); + + _while_out_node->input(_while_node); + output()->from(_while_out_node); + + // cond graph + _less_node->x(_cond_input_node); + _less_node->y(_less_const_node); + _cond_output_node->from(_less_node); + + // body graph + _add_node_1->x(_body_input_node_1); + _add_node_2->x(_body_input_node_2); + + _body_output_node_5->from(_add_node_1); + _body_output_node_4->from(_add_node_2); + + _gather_node->params(_body_input_node_2); + _fc_node_1->input(_body_input_node_4); + _fc_node_1->weights(_fc_weight_1); + _fc_node_1->bias(_fc_bias_1); + _fc_node_2->input(_gather_node); + _fc_node_2->weights(_fc_weight_2); + _fc_node_2->bias(_fc_bias_2); + + _split_node_1->input(_fc_node_1); + _split_node_2->input(_fc_node_2); + + _split_out_node_1->input(_split_node_1); + _split_out_node_2->input(_split_node_1); + _split_out_node_3->input(_split_node_1); + + _split_out_node_4->input(_split_node_2); + _split_out_node_5->input(_split_node_2); + _split_out_node_6->input(_split_node_2); + + _add_node_3->x(_split_out_node_1); + _add_node_3->y(_split_out_node_4); + + _add_node_4->x(_split_out_node_3); + _add_node_4->y(_split_out_node_6); + + _logistic_node_1->x(_add_node_3); + + _mul_node_1->x(_body_input_node_4); + _mul_node_1->y(_logistic_node_1); + + _sub_node->y(_logistic_node_1); + + _logistic_node_2->x(_add_node_4); + + _mul_node_2->x(_split_out_node_2); + _mul_node_2->y(_logistic_node_2); + + _add_node_5->x(_split_out_node_5); + _add_node_5->y(_mul_node_2); + + _tanh_node->x(_add_node_5); + + _mul_node_3->x(_sub_node); + _mul_node_3->y(_tanh_node); + + _add_node_6->x(_mul_node_1); + _add_node_6->y(_mul_node_3); + + _reshape_node->shape(_add_node_6); + + _body_output_node_3->from(_reshape_node); + } +}; + +} // namespace + +TEST(FuseGRUPassTest, fuse_pattern1) +{ + FuseGRUTestGraph1 g; + luci::FuseGRUPass pass; + + g.init(); + + EXPECT_TRUE(pass.run(g.g())); +} + +TEST(FuseGRUPassTest, fuse_NEG) +{ + FuseGRUTestNegGraph g; + luci::FuseGRUPass pass; + + g.init(); + + EXPECT_FALSE(pass.run(g.g())); +}