Skip to content

Commit

Permalink
[onert] Replace LossParam with std::variant (#14681)
Browse files Browse the repository at this point in the history
This commit replaces LossParam with std::variant.

This commit applies std::variant to

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Feb 18, 2025
1 parent 5518153 commit 02e5d99
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
13 changes: 9 additions & 4 deletions runtime/onert/backend/train/KernelGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,8 @@ void KernelGenerator::visit(const ir::train::operation::Loss &node)
// loss
auto back_prop_y_pred_tensor = getBackPropIn(node, y_pred_index);

auto loss_code = node.param().loss_code;
auto loss_param = node.param().loss_param;
const auto loss_code = node.param().loss_code;
const auto &loss_param = node.param().loss_param;
const auto reduction_type = node.param().reduction_type;

switch (loss_code)
Expand All @@ -418,16 +418,21 @@ void KernelGenerator::visit(const ir::train::operation::Loss &node)
{
const auto y_pred_op_code = node.y_pred_op_code();
bool is_normalization_required = (y_pred_op_code != ir::OpCode::Softmax);
const auto cce_params = std::get_if<ir::train::CategoricalCrossentropyParam>(&loss_param);
if (!cce_params)
{
throw std::runtime_error("LossLayer: Expected loss_param to be "
"CategoricalCrossentropyParam but found a different type.");
}
auto fn = std::make_unique<ops::LossCategoricalCrossentropyLayer>();
fn->configure(y_pred_tensor, y_true_tensor, output_tensor, back_prop_y_pred_tensor,
reduction_type, loss_param.cce.axis, loss_param.cce.label_smoothing,
reduction_type, cce_params->axis, cce_params->label_smoothing,
is_normalization_required);
_return_fn = std::move(fn);
break;
}
default:
throw std::runtime_error("LossLayer: unsupported loss type");
break;
}
}

Expand Down
9 changes: 5 additions & 4 deletions runtime/onert/core/include/ir/train/LossInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

#include "LossCode.h"

#include <variant>
#include <utility>

namespace onert::ir::train
{

Expand All @@ -39,13 +42,11 @@ struct LossInfo
{
LossCode loss_code;
LossReductionType reduction_type;
union LossParam {
CategoricalCrossentropyParam cce;
} loss_param;
std::variant<std::monostate, CategoricalCrossentropyParam> loss_param;

LossInfo()
: loss_code{LossCode::Undefined}, reduction_type{LossReductionType::Undefined},
loss_param{-1, 0.0f}
loss_param{CategoricalCrossentropyParam{-1, 0.0f}}
{
}
};
Expand Down

0 comments on commit 02e5d99

Please sign in to comment.