diff --git a/compiler/luci/pass/src/FoldCastPass.cpp b/compiler/luci/pass/src/FoldCastPass.cpp index 81b3ed65438..3ecbe2b4cbc 100644 --- a/compiler/luci/pass/src/FoldCastPass.cpp +++ b/compiler/luci/pass/src/FoldCastPass.cpp @@ -26,14 +26,20 @@ luci::CircleConst *cast_const(luci::CircleConst *node, loco::DataType from_dtype { assert(node->dtype() == from_dtype); - bool do_casting = false; + enum CAST_TYPES + { + CAST_NONE = 0, + CAST_S64_S32, + }; + + CAST_TYPES cast_type = CAST_NONE; if (from_dtype == loco::DataType::S64) { if (to_dtype == loco::DataType::S32) - do_casting = true; + cast_type = CAST_S64_S32; } // TODO: Support more data types - if (not do_casting) + if (cast_type == CAST_NONE) return nullptr; auto name = node->name(); @@ -52,9 +58,9 @@ luci::CircleConst *cast_const(luci::CircleConst *node, loco::DataType from_dtype constant->shape_status(luci::ShapeStatus::VALID); // TODO: Support more data types - if (from_dtype == loco::DataType::S64) + switch (cast_type) { - if (to_dtype == loco::DataType::S32) + case CAST_S64_S32: { constant->size(num_elems); for (uint32_t i = 0; i < num_elems; i++) @@ -64,7 +70,8 @@ luci::CircleConst *cast_const(luci::CircleConst *node, loco::DataType from_dtype constant->name(name + "_S32"); return constant; } - return nullptr; + default: + break; } return nullptr;