Skip to content

Commit

Permalink
[luci/import] CircleConstNodeBuilder::build (#13879)
Browse files Browse the repository at this point in the history
This will revise CircleConstNodeBuilder::build to load extended Buffer.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
Co-authored-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
seanshpark and jinevening authored Sep 2, 2024
1 parent 481547b commit ebe15bf
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion compiler/luci/import/src/Nodes/CircleConst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@

#include "luci/Import/Nodes/CircleConst.h"

#include "luci/Import/CircleReader.h"

#include <luci/IR/Nodes/CircleConst.h>
#include <luci/Log.h>

#include <loco.h>
#include <oops/UserExn.h>

#include <cassert>
#include <limits>
#include <ostream>
#include <string>
#include <vector>

#include <string.h>

namespace
{

Expand Down Expand Up @@ -160,7 +165,54 @@ CircleNode *CircleConstNodeBuilder::build(TensorIndex tensor_index,
const auto c_buffer = const_tensor->buffer();
const auto r_buffer = r_buffers[c_buffer];
assert(r_buffer != nullptr);
const auto buffer = wrap(r_buffer->data());
if (r_buffer->offset() == 1 || r_buffer->size() == 1)
{
// NOTE this shouldn't happen
throw std::runtime_error("Cirlce file with invalid extended Buffer.");
}
// temporary buffer to provide raw data from file
// must have life time same or longer than 'buffer' variable
std::vector<uint8_t> temp_buffer;
luci::VectorWrapper<uint8_t> buffer(nullptr);
if (r_buffer->offset() > 1)
{
if (r_buffer->size() >= std::numeric_limits<uint32_t>::max())
{
// NOTE uint32_t limit is to match "uoffset_t flatbuffers::Vector::size()"
throw std::runtime_error("Cirlce file with invalid extended Buffer.");
}
uint32_t r_size = static_cast<uint32_t>(r_buffer->size());
// match binary level to flatbuffers::Vector
temp_buffer.resize(r_size + sizeof(uint32_t));

uint8_t *t_data = temp_buffer.data();
const uint8_t *f_data = reader->file_data(r_buffer->offset());
if (f_data == nullptr)
{
// NOTE this shouldn't happen
assert(false);
return nullptr;
}
memcpy(t_data, &r_size, sizeof(r_size));
t_data = t_data + sizeof(r_size);
if (r_buffer->offset() + r_buffer->size() > reader->file_size())
{
// NOTE this shouldn't happen
assert(false);
return nullptr;
}
memcpy(t_data, f_data, r_buffer->size());

using fbv_t = flatbuffers::Vector<uint8_t>;
const fbv_t *v_data = reinterpret_cast<const fbv_t *>(temp_buffer.data());
buffer = wrap(v_data);

context->ext_buffer(true);
}
else
{
buffer = wrap(r_buffer->data());
}
const auto const_dims = wrap(const_tensor->shape()); // in NHWC
if (const_dims.size() == 0 && buffer.empty())
{
Expand Down

0 comments on commit ebe15bf

Please sign in to comment.