diff --git a/compiler/tflchef/core/src/ModelChef.cpp b/compiler/tflchef/core/src/ModelChef.cpp index 4025ed2306f..5ccf5bdf675 100644 --- a/compiler/tflchef/core/src/ModelChef.cpp +++ b/compiler/tflchef/core/src/ModelChef.cpp @@ -47,38 +47,6 @@ using namespace souschef; namespace { -class GeneratedModelImpl final : public tflchef::GeneratedModel::Impl -{ -public: - GeneratedModelImpl(std::unique_ptr &&builder) - : _builder{std::move(builder)} - { - // DO NOTHING - } - -public: - const char *base(void) const override - { - // Return the base address of generated flatbuffer model - return reinterpret_cast(_builder->GetBufferPointer()); - } - -public: - size_t size(void) const override - { - // Return the size of generated flatbuffer model - return _builder->GetSize(); - } - -private: - std::unique_ptr _builder; -}; - -} // namespace - -namespace -{ - struct DataChefRegistry final : public Registry { }; @@ -209,6 +177,8 @@ std::set gather_customcode_set(const ::tflchef::ModelRecipe &model_ namespace { +// TODO remove +#if 0 struct CookParams { std::vector> &buffer_vec; @@ -219,6 +189,20 @@ struct CookParams std::vector &custom_code_vec; std::string noname; }; +#endif + +struct ModelChef +{ + std::unique_ptr flatbuffer_builder; + + std::vector> signdef_vec; + std::vector> buffer_vec; + std::vector> code_vec; + std::vector> subgraph_vec; + std::map builtin_code_map; + std::vector custom_code_vec; + std::string graph_name; +}; std::vector> make_dim_metadata_vec(flatbuffers::FlatBufferBuilder *flatbuffer_builder, int32_t dims_count, @@ -255,16 +239,17 @@ make_dim_metadata_vec(flatbuffers::FlatBufferBuilder *flatbuffer_builder, int32_ return dim_metadata_vec; } -template std::map cook_graph(const T &graph, CookParams &cp) +template std::map cook_graph(const T &graph, ModelChef &mc) { LOGGER(l); - std::vector> &buffer_vec = cp.buffer_vec; - std::vector> &code_vec = cp.code_vec; - std::vector> &subgraph_vec = cp.subgraph_vec; - std::unique_ptr &flatbuffer_builder = cp.flatbuffer_builder; - std::map &builtin_code_map = cp.builtin_code_map; - std::vector &custom_code_vec = cp.custom_code_vec; + // TODO remove references + std::vector> &buffer_vec = mc.buffer_vec; + std::vector> &code_vec = mc.code_vec; + std::vector> &subgraph_vec = mc.subgraph_vec; + std::unique_ptr &flatbuffer_builder = mc.flatbuffer_builder; + std::map &builtin_code_map = mc.builtin_code_map; + std::vector &custom_code_vec = mc.custom_code_vec; // Operand-related std::vector> tensor_vec; @@ -273,7 +258,7 @@ template std::map cook_graph(const T &graph, std::vector> operator_vec; // default name for graph - std::string graph_name = cp.noname; + std::string graph_name = mc.graph_name; if (graph.has_name()) graph_name = graph.name(); @@ -722,6 +707,40 @@ template std::map cook_graph(const T &graph, } // namespace +namespace +{ + +class GeneratedModelImpl final : public tflchef::GeneratedModel::Impl +{ +public: + GeneratedModelImpl() + { + // DO NOTHING + } + +public: + const char *base(void) const override + { + // Return the base address of generated flatbuffer model + return reinterpret_cast(_mc.flatbuffer_builder->GetBufferPointer()); + } + +public: + size_t size(void) const override + { + // Return the size of generated flatbuffer model + return _mc.flatbuffer_builder->GetSize(); + } + +public: + ModelChef &model_chef(void) { return _mc; } + +private: + ModelChef _mc; +}; + +} // namespace + namespace tflchef { @@ -743,26 +762,35 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe) #include "DataChef.def" #undef DATA_CHEF + std::unique_ptr gen_model(new GeneratedModelImpl()); + + ModelChef &mc = gen_model->model_chef(); + + mc.flatbuffer_builder = + std::unique_ptr(new flatbuffers::FlatBufferBuilder(1024)); + + // TODO remove references + // // Create FlatBufferBuilder // - auto flatbuffer_builder = - std::unique_ptr(new flatbuffers::FlatBufferBuilder(1024)); + std::unique_ptr &flatbuffer_builder = mc.flatbuffer_builder; // Operand-related - std::vector> buffer_vec; + std::vector> &buffer_vec = mc.buffer_vec; // Operation-related - std::vector> code_vec; + std::vector> &code_vec = mc.code_vec; // SignatureDef-related - std::vector> signdef_vec; + std::vector> &signdef_vec = mc.signdef_vec; // Graphs-related - std::vector> subgraph_vec; + std::vector> &subgraph_vec = mc.subgraph_vec; // Create OperatorCode with Builtin Operator - auto builtin_code_map = gather_builtincode_map(model_recipe); + mc.builtin_code_map = gather_builtincode_map(model_recipe); + std::map &builtin_code_map = mc.builtin_code_map; for (auto const &opcode : builtin_code_map) { tflite::OperatorCodeBuilder code_builder{*flatbuffer_builder}; @@ -788,7 +816,8 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe) // Create OperatorCode with Custom Operator std::set custom_code_set = gather_customcode_set(model_recipe); - std::vector custom_code_vec{custom_code_set.begin(), custom_code_set.end()}; + mc.custom_code_vec = {custom_code_set.begin(), custom_code_set.end()}; + std::vector &custom_code_vec = mc.custom_code_vec; for (auto opcode : custom_code_vec) { @@ -818,10 +847,9 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe) // // Create Main graph // - CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder, - builtin_code_map, custom_code_vec, "main"}; - auto table = cook_graph<::tflchef::ModelRecipe>(model_recipe, cp); + mc.graph_name = "main"; + auto table = cook_graph<::tflchef::ModelRecipe>(model_recipe, mc); symbol_tables.push_back(table); // @@ -834,10 +862,9 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe) std::ostringstream stringStream; stringStream << "sub_" << (g + 1); - CookParams cp{buffer_vec, code_vec, subgraph_vec, flatbuffer_builder, - builtin_code_map, custom_code_vec, stringStream.str()}; + mc.graph_name = stringStream.str(); - auto table = cook_graph<::tflchef::Graph>(graph, cp); + auto table = cook_graph<::tflchef::Graph>(graph, mc); symbol_tables.push_back(table); } @@ -946,8 +973,7 @@ GeneratedModel cook(const ::tflchef::ModelRecipe &model_recipe) ::tflite::FinishModelBuffer(*flatbuffer_builder, model); // Return "GenerateModel" - return GeneratedModel{ - std::unique_ptr(new GeneratedModelImpl(std::move(flatbuffer_builder)))}; + return GeneratedModel{std::move(gen_model)}; } } // namespace tflchef