From 4988e8310eac55d50b5a26f1ea304a4933f155fd Mon Sep 17 00:00:00 2001 From: seongwoo Date: Wed, 27 Nov 2024 13:13:39 +0900 Subject: [PATCH] apply comments and add error handling logic. --- .../src/CircleInterpreter_cffi.cpp | 117 ++++++++++++++++-- 1 file changed, 105 insertions(+), 12 deletions(-) diff --git a/compiler/circle-interpreter/src/CircleInterpreter_cffi.cpp b/compiler/circle-interpreter/src/CircleInterpreter_cffi.cpp index f0b5dde9dde..523854034f5 100644 --- a/compiler/circle-interpreter/src/CircleInterpreter_cffi.cpp +++ b/compiler/circle-interpreter/src/CircleInterpreter_cffi.cpp @@ -21,10 +21,17 @@ */ #include +#include #include #include +namespace +{ + +// Global variable for error message +std::string last_error_message; + template size_t getTensorSize(const NodeT *node) { uint32_t tensor_size = luci::size(node->dtype()); @@ -33,6 +40,64 @@ template size_t getTensorSize(const NodeT *node) return tensor_size; } +// Function to retrieve the last error message. +extern "C" const char *get_last_error(void) { return last_error_message.c_str(); } + +// Clear the last error message +extern "C" void clear_last_error(void) { last_error_message.clear(); } + +/** + * @brief A function that wraps another function and catches any exceptions. + * + * This function executes the given callable (`func`) with the provided arguments. + * If the callable throws an exception, the exception message is stored in a + * `last_error_message` global variable. + * + * @tparam Func The type of the callable funciton. + * @tparam Args The types of arguments to pass to the callable function. + * @param func The callable function to execute. + * @param args The arguments to pass to the callable function. + * @return The return value of the callable function, or a default value in case of + * an exception. If the function has a `void` return type, it simply returns + * without any value. + * + * @note This function ensures that exceptions are safely caught and conveted to + * error messages that can be queried externally, e.g. from Python. + */ +template +auto exception_wrapper(Func func, Args... args) -> typename std::result_of::type +{ + using ReturnType = typename std::result_of::type; + + try + { + return func(std::forward(args)...); + } + catch (const std::exception &e) + { + last_error_message = e.what(); + if constexpr (not std::is_void::value) + { + return ReturnType{}; + } + } + catch (...) + { + last_error_message = "Unknown error"; + if constexpr (not std::is_void::value) + { + return ReturnType{}; + } + } + // For void return type + if constexpr (std::is_void::value) + { + return; + } +} + +} // namespace + /* * Q) Why do we need this wrapper class? * @@ -60,6 +125,10 @@ class InterpreterWrapper { luci::Importer importer; _module = importer.importModule(data, data_size); + if (_module == nullptr) + { + throw std::runtime_error{"Cannot import module."}; + } _intp = new luci_interpreter::Interpreter(_module.get()); } @@ -67,11 +136,19 @@ class InterpreterWrapper void interpret(void) { _intp->interpret(); } - void writeInputTensor(const int input_idx, const void *data) + void writeInputTensor(const int input_idx, const void *data, size_t input_size) { const auto input_nodes = loco::input_nodes(_module->graph()); - const auto target_input = loco::must_cast(input_nodes.at(input_idx)); - _intp->writeInputTensor(target_input, data, ::getTensorSize(target_input)); + const auto input_node = loco::must_cast(input_nodes.at(input_idx)); + // Input size from model binary + const auto fb_input_size = ::getTensorSize(input_node); + if (fb_input_size != input_size) + { + const auto msg = "Invalid input size: " + std::to_string(fb_input_size) + + " != " + std::to_string(input_size); + throw std::runtime_error(msg); + } + _intp->writeInputTensor(input_node, data, fb_input_size); } void readOutputTensor(const int output_idx, void *output, size_t output_size) @@ -79,11 +156,18 @@ class InterpreterWrapper const auto output_nodes = loco::output_nodes(_module->graph()); const auto output_node = loco::must_cast(output_nodes.at(output_idx)); - _intp->readOutputTensor(output_node, output, output_size); + const auto fb_output_size = ::getTensorSize(output_node); + if (fb_output_size != output_size) + { + const auto msg = "Invalid output size: " + std::to_string(fb_output_size) + + " != " + std::to_string(output_size); + throw std::runtime_error(msg); + } + _intp->readOutputTensor(output_node, output, fb_output_size); } private: - luci_interpreter::Interpreter *_intp; + luci_interpreter::Interpreter *_intp = nullptr; std::unique_ptr _module; }; @@ -95,23 +179,32 @@ class InterpreterWrapper * - Explicitly pass the object pointer to any funcitons that operates on the object. */ extern "C" { + InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size) { - return new InterpreterWrapper(data, data_size); + return ::exception_wrapper([&]() { return new InterpreterWrapper(data, data_size); }); } -void Interpreter_delete(InterpreterWrapper *intp) { delete intp; } +void Interpreter_delete(InterpreterWrapper *intp) +{ + ::exception_wrapper([&]() { delete intp; }); +} -void Interpreter_interpret(InterpreterWrapper *intp) { intp->interpret(); } +void Interpreter_interpret(InterpreterWrapper *intp) +{ + ::exception_wrapper([&]() { intp->interpret(); }); +} -void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data) +void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data, + size_t input_size) { - intp->writeInputTensor(input_idx, data); + ::exception_wrapper([&]() { intp->writeInputTensor(input_idx, data, input_size); }); } void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output, size_t output_size) { - intp->readOutputTensor(output_idx, output, output_size); -} + ::exception_wrapper([&]() { intp->readOutputTensor(output_idx, output, output_size); }); } + +} // extern "C"