implementation refactor
mbencer committed Dec 2, 2024
1 parent a51addd commit e2005d8
Showing 4 changed files with 51 additions and 58 deletions.
4 changes: 2 additions & 2 deletions runtime/onert/backend/cpu/Backend.h
@@ -20,6 +20,7 @@
#include "BackendContext.h"
#include "Config.h"
#include "KernelGenerator.h"
#include "SharedMemoryOperands.h"

#include <backend/Backend.h>

@@ -45,8 +46,7 @@ class Backend : public ::onert::backend::Backend
auto &graph = *data.graph;
auto context = std::make_unique<BackendContext>(this, std::move(data));
auto tr = std::make_shared<basic::TensorRegistry>();
// TODO: Use findSharedMemoryOperandIndexes method here
auto tb = std::make_shared<TensorBuilder>(tr, ir::OperandIndexMap<ir::OperandIndex>{});
auto tb = std::make_shared<TensorBuilder>(tr, findSharedMemoryOperandIndexes(graph));
context->tensor_registry = tr;
context->tensor_builder = tb;
context->kernel_gen = std::make_shared<KernelGenerator>(graph, tb, tr, custom_kernel_builder,
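For context, the second TensorBuilder argument is a destination-to-source operand map: each key is an operand that may reuse another operand's buffer, and each value is the source operand that owns the memory (for example, a Reshape output that can alias its input). A minimal, self-contained sketch of that relationship, using plain integers in place of ir::OperandIndex (names and values are illustrative, not part of this commit):

#include <unordered_map>

int main()
{
  // destination operand -> source operand whose buffer it will share
  std::unordered_map<int, int> shared_memory_operands;
  shared_memory_operands[7] = 3; // e.g. a Reshape output (7) aliasing its input (3)

  // A TensorBuilder-style consumer only needs lookups: does operand 7 own its
  // buffer, or should it borrow operand 3's allocation?
  return shared_memory_operands.count(7) != 0 ? 0 : 1;
}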
31 changes: 20 additions & 11 deletions runtime/onert/core/include/backend/basic/BackendContextHelpers.h
@@ -177,30 +177,39 @@ void planTensors(const std::shared_ptr<T_TensorBuilder> &tensor_builder, const i
}

template <typename T_TensorBuilder>
ITensorRegistry *genTensors(const std::shared_ptr<T_TensorBuilder> &tensor_builder,
const ir::Graph &graph,
const util::Set<ir::OperandIndex> &external_operands,
const std::shared_ptr<ITensorRegistry> &tensor_registry,
const std::vector<onert::ir::OperationIndex> &op_order,
const ir::OperandIndexMap<ir::OperandIndex> &shared_memory_operand_idx)
ir::OperandIndexSequence register_source_memory_tensors(
const std::shared_ptr<T_TensorBuilder> &tensor_builder, const ir::Graph &graph,
const util::Set<ir::OperandIndex> &external_operands,
const ir::OperandIndexMap<ir::OperandIndex> &shared_memory_operand_idx)
{
// process source tensors for shared memory at first
std::vector<ir::OperandIndex> registered_source_ind;
// Process source tensors that share memory first
ir::OperandIndexSequence registered_source_ind;
for (const auto &[_, source_ind] : shared_memory_operand_idx)
{
if (external_operands.contains(source_ind))
continue;
if (tensor_builder->isRegistered(source_ind)) // some tensors can have the same source
continue;
tensor_builder->registerTensorInfo(source_ind, graph.operands().at(source_ind).info());
registered_source_ind.emplace_back(source_ind);
registered_source_ind.append(source_ind);
}
return registered_source_ind;
}

template <typename T_TensorBuilder>
ITensorRegistry *genTensors(const std::shared_ptr<T_TensorBuilder> &tensor_builder,
const ir::Graph &graph,
const util::Set<ir::OperandIndex> &external_operands,
const std::shared_ptr<ITensorRegistry> &tensor_registry,
const std::vector<onert::ir::OperationIndex> &op_order,
const ir::OperandIndexMap<ir::OperandIndex> &shared_memory_operand_idx)
{
const auto registered_source_ind = register_source_memory_tensors(
tensor_builder, graph, external_operands, shared_memory_operand_idx);
graph.operands().iterate([&](const ir::OperandIndex &ind, const ir::Operand &obj) {
if (external_operands.contains(ind))
return;
if (std::find(std::begin(registered_source_ind), std::end(registered_source_ind), ind) !=
std::end(registered_source_ind)) // skip tensors already registered
if (registered_source_ind.contains(ind)) // skip tensors already registered
return;
tensor_builder->registerTensorInfo(ind, obj.info());
});
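The extracted register_source_memory_tensors helper keeps the original registration order: memory-source operands are registered first, and genTensors then walks every operand, skipping anything already registered (now via OperandIndexSequence::contains instead of std::find over a std::vector). A rough, self-contained sketch of that two-pass pattern using standard containers (register_tensor stands in for tensor_builder->registerTensorInfo; all names are illustrative):

#include <set>
#include <utility>
#include <vector>

static void register_tensor(int /*ind*/) { /* stand-in for registerTensorInfo */ }

// shared: destination -> source operand pairs; all_operands: every operand in the graph
static void two_pass_registration(const std::vector<std::pair<int, int>> &shared,
                                  const std::vector<int> &all_operands)
{
  std::set<int> registered;
  // Pass 1: register the source side of every shared-memory pair first.
  for (const auto &dest_source : shared)
  {
    const int source = dest_source.second;
    if (registered.insert(source).second) // several destinations may share one source
      register_tensor(source);
  }
  // Pass 2: register the remaining operands, skipping the already-registered sources.
  for (const int ind : all_operands)
  {
    if (registered.count(ind) == 0)
      register_tensor(ind);
  }
}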
5 changes: 5 additions & 0 deletions runtime/onert/core/include/backend/basic/StaticTensorManager.h
@@ -54,6 +54,11 @@ class StaticTensorManager

void iterate(const std::function<void(const ir::OperandIndex &)> &fn);

private:
// Return the source memory operand index if a source memory operand exists.
// Otherwise, return the given index unchanged.
ir::OperandIndex adjust_with_memory_source_operand(const ir::OperandIndex &ind);

private:
std::unique_ptr<MemoryManager> _nonconst_mgr;
const std::shared_ptr<TensorRegistry> _tensors;
69 changes: 24 additions & 45 deletions runtime/onert/core/src/backend/basic/StaticTensorManager.cc
@@ -56,28 +56,14 @@ void StaticTensorManager::allocateNonconsts(void)

for (auto &&[ind, tensor] : _tensors->native_tensors())
{
bool buffer_set = false;
if (!tensor->is_dynamic())
const auto adjusted_ind = adjust_with_memory_source_operand(ind);
if (!_as_constants[adjusted_ind] && !tensor->is_dynamic())
{
if (_shared_memory_operand_indexes.find(ind) != std::end(_shared_memory_operand_indexes))
{
const auto &shared_memory_ind = _shared_memory_operand_indexes[ind];
if (!_as_constants[shared_memory_ind])
{
tensor->setBuffer(_nonconst_mgr->getBuffer(shared_memory_ind));
buffer_set = true;
}
}
else if (!_as_constants[ind])
{
tensor->setBuffer(_nonconst_mgr->getBuffer(ind));
buffer_set = true;
}
if (buffer_set)
{
VERBOSE(CPU_StaticTensorManager)
<< "TENSOR " << ind << " : " << static_cast<void *>(tensor->buffer()) << std::endl;
}
auto *buffer = _nonconst_mgr->getBuffer(adjusted_ind);
tensor->setBuffer(buffer);

VERBOSE(CPU_StaticTensorManager)
<< "TENSOR " << ind << " : " << static_cast<void *>(buffer) << std::endl;
}
}
}
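With the source lookup folded into adjust_with_memory_source_operand, every non-constant, non-dynamic tensor now takes its buffer from the adjusted index, so an operand that shares memory ends up with exactly the same pointer as its source operand. A standalone illustration of that fallback lookup and the resulting aliasing, with plain pointers standing in for MemoryManager::getBuffer (all names are illustrative):

#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Return the memory-source operand if one is registered; otherwise return the operand itself.
static int adjust_with_memory_source(const std::unordered_map<int, int> &shared, int ind)
{
  const auto it = shared.find(ind);
  return it != shared.end() ? it->second : ind;
}

int main()
{
  std::vector<std::uint8_t> arena(64);              // stand-in for the planned memory pool
  std::unordered_map<int, std::uint8_t *> buffers;  // stand-in for MemoryManager::getBuffer
  std::unordered_map<int, int> shared{{7, 3}};      // operand 7 shares operand 3's memory

  buffers[3] = arena.data();                        // only the source operand owns an allocation

  std::uint8_t *buf3 = buffers[adjust_with_memory_source(shared, 3)];
  std::uint8_t *buf7 = buffers[adjust_with_memory_source(shared, 7)];
  assert(buf3 == buf7); // both tensors alias the same memory
  return 0;
}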
@@ -95,14 +81,14 @@ void StaticTensorManager::buildTensor(const ir::OperandIndex &ind,
}
else
{
const auto source_operand_ind = _shared_memory_operand_indexes.find(ind);
if (source_operand_ind != std::end(_shared_memory_operand_indexes) &&
_as_constants[source_operand_ind->second])
const auto source_operand_ind = adjust_with_memory_source_operand(ind);
if (_as_constants[source_operand_ind])
{
as_const = _as_constants[source_operand_ind->second];
auto new_tensor_info = tensor_info;
new_tensor_info.setAsConstant();
// source memory tensor is a constant
tensor = std::make_unique<ExternalTensor>(new_tensor_info);
as_const = true;
}
else
{
@@ -122,16 +108,7 @@ void StaticTensorManager::claimPlan(const ir::OperandIndex &ind, uint32_t size)
// This method is called only when a tensor has proper shape
assert(!_tensors->getNativeTensor(ind)->is_dynamic());

ir::OperandIndex claim_ind;
const auto source_ind = _shared_memory_operand_indexes.find(ind);
if (source_ind == std::end(_shared_memory_operand_indexes))
{
claim_ind = ind;
}
else
{
claim_ind = source_ind->second;
}
const auto claim_ind = adjust_with_memory_source_operand(ind);
if (_as_constants[claim_ind])
{
return;
@@ -151,16 +128,7 @@ void StaticTensorManager::releasePlan(const ir::OperandIndex &ind)
// This method is called only when a tensor has proper shape
assert(!_tensors->getNativeTensor(ind)->is_dynamic());

ir::OperandIndex release_ind;
const auto source_operand_ind_ind = _shared_memory_operand_indexes.find(ind);
if (source_operand_ind_ind == std::end(_shared_memory_operand_indexes))
{
release_ind = ind;
}
else
{
release_ind = source_operand_ind_ind->second;
}
const auto release_ind = adjust_with_memory_source_operand(ind);
if (_as_constants[release_ind])
{
return;
@@ -182,6 +150,17 @@ void StaticTensorManager::iterate(const std::function<void(const ir::OperandInde
fn(it.first);
}

ir::OperandIndex StaticTensorManager::adjust_with_memory_source_operand(const ir::OperandIndex &ind)
{
const auto source_operand_ind = _shared_memory_operand_indexes.find(ind);
if (source_operand_ind != std::end(_shared_memory_operand_indexes))
{
return source_operand_ind->second;
}
// source memory operand not found
return ind;
}

} // namespace basic
} // namespace backend
} // namespace onert
