Skip to content

Commit

Permalink
apply comments and add error handling logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
mhs4670go committed Nov 27, 2024
1 parent f5e5c42 commit 4988e83
Showing 1 changed file with 105 additions and 12 deletions.
117 changes: 105 additions & 12 deletions compiler/circle-interpreter/src/CircleInterpreter_cffi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,17 @@
*/

#include <cstddef>
#include <string>

#include <luci/Importer.h>
#include <luci_interpreter/Interpreter.h>

namespace
{

// Global variable for error message
std::string last_error_message;

template <typename NodeT> size_t getTensorSize(const NodeT *node)
{
uint32_t tensor_size = luci::size(node->dtype());
Expand All @@ -33,6 +40,64 @@ template <typename NodeT> size_t getTensorSize(const NodeT *node)
return tensor_size;
}

// Function to retrieve the last error message.
extern "C" const char *get_last_error(void) { return last_error_message.c_str(); }

// Clear the last error message
extern "C" void clear_last_error(void) { last_error_message.clear(); }

/**
* @brief A function that wraps another function and catches any exceptions.
*
* This function executes the given callable (`func`) with the provided arguments.
* If the callable throws an exception, the exception message is stored in a
* `last_error_message` global variable.
*
* @tparam Func The type of the callable funciton.
* @tparam Args The types of arguments to pass to the callable function.
* @param func The callable function to execute.
* @param args The arguments to pass to the callable function.
* @return The return value of the callable function, or a default value in case of
* an exception. If the function has a `void` return type, it simply returns
* without any value.
*
* @note This function ensures that exceptions are safely caught and conveted to
* error messages that can be queried externally, e.g. from Python.
*/
template <typename Func, typename... Args>
auto exception_wrapper(Func func, Args... args) -> typename std::result_of<Func(Args...)>::type
{
using ReturnType = typename std::result_of<Func(Args...)>::type;

try
{
return func(std::forward<Args>(args)...);
}
catch (const std::exception &e)
{
last_error_message = e.what();
if constexpr (not std::is_void<ReturnType>::value)
{
return ReturnType{};
}
}
catch (...)
{
last_error_message = "Unknown error";
if constexpr (not std::is_void<ReturnType>::value)
{
return ReturnType{};
}
}
// For void return type
if constexpr (std::is_void<ReturnType>::value)
{
return;
}
}

} // namespace

/*
* Q) Why do we need this wrapper class?
*
Expand Down Expand Up @@ -60,30 +125,49 @@ class InterpreterWrapper
{
luci::Importer importer;
_module = importer.importModule(data, data_size);
if (_module == nullptr)
{
throw std::runtime_error{"Cannot import module."};
}
_intp = new luci_interpreter::Interpreter(_module.get());
}

~InterpreterWrapper() { delete _intp; }

void interpret(void) { _intp->interpret(); }

void writeInputTensor(const int input_idx, const void *data)
void writeInputTensor(const int input_idx, const void *data, size_t input_size)
{
const auto input_nodes = loco::input_nodes(_module->graph());
const auto target_input = loco::must_cast<const luci::CircleInput *>(input_nodes.at(input_idx));
_intp->writeInputTensor(target_input, data, ::getTensorSize(target_input));
const auto input_node = loco::must_cast<const luci::CircleInput *>(input_nodes.at(input_idx));
// Input size from model binary
const auto fb_input_size = ::getTensorSize(input_node);
if (fb_input_size != input_size)
{
const auto msg = "Invalid input size: " + std::to_string(fb_input_size) +
" != " + std::to_string(input_size);
throw std::runtime_error(msg);
}
_intp->writeInputTensor(input_node, data, fb_input_size);
}

void readOutputTensor(const int output_idx, void *output, size_t output_size)
{
const auto output_nodes = loco::output_nodes(_module->graph());
const auto output_node =
loco::must_cast<const luci::CircleOutput *>(output_nodes.at(output_idx));
_intp->readOutputTensor(output_node, output, output_size);
const auto fb_output_size = ::getTensorSize(output_node);
if (fb_output_size != output_size)
{
const auto msg = "Invalid output size: " + std::to_string(fb_output_size) +
" != " + std::to_string(output_size);
throw std::runtime_error(msg);
}
_intp->readOutputTensor(output_node, output, fb_output_size);
}

private:
luci_interpreter::Interpreter *_intp;
luci_interpreter::Interpreter *_intp = nullptr;
std::unique_ptr<luci::Module> _module;
};

Expand All @@ -95,23 +179,32 @@ class InterpreterWrapper
* - Explicitly pass the object pointer to any funcitons that operates on the object.
*/
extern "C" {

InterpreterWrapper *Interpreter_new(const uint8_t *data, const size_t data_size)
{
return new InterpreterWrapper(data, data_size);
return ::exception_wrapper([&]() { return new InterpreterWrapper(data, data_size); });
}

void Interpreter_delete(InterpreterWrapper *intp) { delete intp; }
void Interpreter_delete(InterpreterWrapper *intp)
{
::exception_wrapper([&]() { delete intp; });
}

void Interpreter_interpret(InterpreterWrapper *intp) { intp->interpret(); }
void Interpreter_interpret(InterpreterWrapper *intp)
{
::exception_wrapper([&]() { intp->interpret(); });
}

void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data)
void Interpreter_writeInputTensor(InterpreterWrapper *intp, const int input_idx, const void *data,
size_t input_size)
{
intp->writeInputTensor(input_idx, data);
::exception_wrapper([&]() { intp->writeInputTensor(input_idx, data, input_size); });
}

void Interpreter_readOutputTensor(InterpreterWrapper *intp, const int output_idx, void *output,
size_t output_size)
{
intp->readOutputTensor(output_idx, output, output_size);
}
::exception_wrapper([&]() { intp->readOutputTensor(output_idx, output, output_size); });
}

} // extern "C"

0 comments on commit 4988e83

Please sign in to comment.