Skip to content

Commit

Permalink
[onert/api] Add nnfw_train_get_traininfo (Samsung#12391)
Browse files Browse the repository at this point in the history
This PR adds nnfw_train_get_traininfo API.
This api is to get training info from the session.

ONE-DCO-1.0-Signed-off-by: SeungHui Youn <[email protected]>
  • Loading branch information
zetwhite authored Jan 3, 2024
1 parent 4f4a24e commit 094fc01
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 7 deletions.
33 changes: 26 additions & 7 deletions runtime/onert/api/include/nnfw_experimental.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,24 +176,28 @@ NNFW_STATUS nnfw_pop_pipeline_output(nnfw_session *session, void *outputs);
//////////////////////////////////////////////
typedef enum
{
NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR = 0,
NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY = 1,
NNFW_TRAIN_LOSS_UNDEFINED = 0,
NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR = 1,
NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY = 2,
} NNFW_TRAIN_LOSS;

typedef enum
{
/** Undefined */
NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED = 0,
/** Auto */
NNFW_TRAIN_LOSS_REDUCTION_AUTO = 0,
NNFW_TRAIN_LOSS_REDUCTION_AUTO = 1,
/** Scalar sum divided by number of elements in losses */
NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE = 1,
NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE = 2,
/** Scalar sum of weighted losses */
NNFW_TRAIN_LOSS_REDUCTION_SUM = 2,
NNFW_TRAIN_LOSS_REDUCTION_SUM = 3,
} NNFW_TRAIN_LOSS_REDUCTION;

typedef enum
{
NNFW_TRAIN_OPTIMIZER_SGD = 0,
NNFW_TRAIN_OPTIMIZER_ADAM = 1,
NNFW_TRAIN_OPTIMIZER_UNDEFINED = 0,
NNFW_TRAIN_OPTIMIZER_SGD = 1,
NNFW_TRAIN_OPTIMIZER_ADAM = 2,
} NNFW_TRAIN_OPTIMIZER;

typedef struct nnfw_loss_info
Expand All @@ -220,6 +224,21 @@ typedef struct nnfw_train_info
NNFW_TRAIN_OPTIMIZER opt = NNFW_TRAIN_OPTIMIZER_SGD;
} nnfw_train_info;

/**
* @brief Get training information
* @note This function should be called after calling {@link nnfw_load_model_from_file}
*
* For the field which is not set in training information, it returns training information
* filled with default value. The default value of each field is as follows :
* learning_rate = 0.0f, batch_size = 0, *_UNDEF for other enums
*
* @param[in] session The session to get training information
* @param[out] info Training information
*
* @return @c NNFW_STATUS_NO_ERROR if successful
*/
NNFW_STATUS nnfw_train_get_traininfo(nnfw_session *session, nnfw_train_info *info);

/**
* @brief Set training information
* @note This function should be called after calling {@link nnfw_load_model_from_file}
Expand Down
6 changes: 6 additions & 0 deletions runtime/onert/api/src/nnfw_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,12 @@ NNFW_STATUS nnfw_pop_pipeline_output(nnfw_session *session, void *outputs)

// Training

NNFW_STATUS nnfw_train_get_traininfo(nnfw_session *session, nnfw_train_info *info)
{
NNFW_RETURN_ERROR_IF_NULL(session);
return session->train_get_traininfo(info);
}

NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_info *info)
{
NNFW_RETURN_ERROR_IF_NULL(session);
Expand Down
85 changes: 85 additions & 0 deletions runtime/onert/api/src/nnfw_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,91 @@ NNFW_STATUS nnfw_session::set_backends_per_operation(const char *backend_setting
return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::train_get_traininfo(nnfw_train_info *info)
{
if (isStateInitialized())
{
// There is no _train_info in INITIALIZED, since _train_info is set when a model loaded
std::cerr << "Error during nnfw_session::train_get_traininfo : invalid state";
return NNFW_STATUS_INVALID_STATE;
}

if (info == nullptr)
{
std::cerr << "Error during nnfw_session::train_get_traininfo : info is nullptr" << std::endl;
return NNFW_STATUS_UNEXPECTED_NULL;
}

// after model loaded, it ensures that _train_info is not nullptr
assert(_train_info != nullptr);

auto convertLossCode = [](const onert::ir::train::LossCode &code) -> NNFW_TRAIN_LOSS {
switch (code)
{
case onert::ir::train::LossCode::Invalid:
return NNFW_TRAIN_LOSS_UNDEFINED;
case onert::ir::train::LossCode::MeanSquaredError:
return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
case onert::ir::train::LossCode::CategoricalCrossentropy:
return NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
default:
throw std::runtime_error{"fail to convert ir::train::LossCode"};
}
};

auto convertLossReduction =
[](const onert::ir::train::LossReductionType &type) -> NNFW_TRAIN_LOSS_REDUCTION {
switch (type)
{
case onert::ir::train::LossReductionType::Invalid:
return NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED;
case onert::ir::train::LossReductionType::Auto:
return NNFW_TRAIN_LOSS_REDUCTION_AUTO;
case onert::ir::train::LossReductionType::SumOverBatchSize:
return NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE;
case onert::ir::train::LossReductionType::Sum:
return NNFW_TRAIN_LOSS_REDUCTION_SUM;
default:
throw std::runtime_error{"fail to convert from ir::train::LossReductionType"};
break;
}
};

auto convertOptimizerCode =
[](const onert::ir::train::OptimizerCode &code) -> NNFW_TRAIN_OPTIMIZER {
switch (code)
{
case onert::ir::train::OptimizerCode::Invalid:
return NNFW_TRAIN_OPTIMIZER_UNDEFINED;
case onert::ir::train::OptimizerCode::SGD:
return NNFW_TRAIN_OPTIMIZER_SGD;
case onert::ir::train::OptimizerCode::Adam:
return NNFW_TRAIN_OPTIMIZER_ADAM;
default:
throw std::runtime_error{"fail to convert from ir::train::OptimizerCode"};
}
};

const auto loss = _train_info->lossInfo();
const auto optim = _train_info->optimizerInfo();

try
{
info->learning_rate = optim.learning_rate;
info->batch_size = _train_info->batchSize();
info->loss_info.loss = convertLossCode(loss.loss_code);
info->loss_info.reduction_type = convertLossReduction(loss.reduction_type);
info->opt = convertOptimizerCode(optim.optim_code);
}
catch (const std::exception &e)
{
std::cerr << "Error during nnfw_session::train_get_traininfo" << e.what() << std::endl;
return NNFW_STATUS_ERROR;
}

return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::train_set_traininfo(const nnfw_train_info *info)
{
if (not isStateModelLoaded())
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/api/src/nnfw_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ struct nnfw_session
*/
NNFW_STATUS set_backends_per_operation(const char *backend_settings);

NNFW_STATUS train_get_traininfo(nnfw_train_info *info);
NNFW_STATUS train_set_traininfo(const nnfw_train_info *info);
NNFW_STATUS train_prepare();
NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
Expand Down

0 comments on commit 094fc01

Please sign in to comment.