diff --git a/runtime/onert/api/include/nnfw_experimental.h b/runtime/onert/api/include/nnfw_experimental.h index 691f7f9466e..37a56fb2156 100644 --- a/runtime/onert/api/include/nnfw_experimental.h +++ b/runtime/onert/api/include/nnfw_experimental.h @@ -176,24 +176,28 @@ NNFW_STATUS nnfw_pop_pipeline_output(nnfw_session *session, void *outputs); ////////////////////////////////////////////// typedef enum { - NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR = 0, - NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY = 1, + NNFW_TRAIN_LOSS_UNDEFINED = 0, + NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR = 1, + NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY = 2, } NNFW_TRAIN_LOSS; typedef enum { + /** Undefined */ + NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED = 0, /** Auto */ - NNFW_TRAIN_LOSS_REDUCTION_AUTO = 0, + NNFW_TRAIN_LOSS_REDUCTION_AUTO = 1, /** Scalar sum divided by number of elements in losses */ - NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE = 1, + NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE = 2, /** Scalar sum of weighted losses */ - NNFW_TRAIN_LOSS_REDUCTION_SUM = 2, + NNFW_TRAIN_LOSS_REDUCTION_SUM = 3, } NNFW_TRAIN_LOSS_REDUCTION; typedef enum { - NNFW_TRAIN_OPTIMIZER_SGD = 0, - NNFW_TRAIN_OPTIMIZER_ADAM = 1, + NNFW_TRAIN_OPTIMIZER_UNDEFINED = 0, + NNFW_TRAIN_OPTIMIZER_SGD = 1, + NNFW_TRAIN_OPTIMIZER_ADAM = 2, } NNFW_TRAIN_OPTIMIZER; typedef struct nnfw_loss_info @@ -220,6 +224,21 @@ typedef struct nnfw_train_info NNFW_TRAIN_OPTIMIZER opt = NNFW_TRAIN_OPTIMIZER_SGD; } nnfw_train_info; +/** + * @brief Get training information + * @note This function should be called after calling {@link nnfw_load_model_from_file} + * + * For the field which is not set in training information, it returns training information + * filled with default value. The default value of each field is as follows : + * learning_rate = 0.0f, batch_size = 0, *_UNDEF for other enums + * + * @param[in] session The session to get training information + * @param[out] info Training information + * + * @return @c NNFW_STATUS_NO_ERROR if successful + */ +NNFW_STATUS nnfw_train_get_traininfo(nnfw_session *session, nnfw_train_info *info); + /** * @brief Set training information * @note This function should be called after calling {@link nnfw_load_model_from_file} diff --git a/runtime/onert/api/src/nnfw_api.cc b/runtime/onert/api/src/nnfw_api.cc index 6ebaf0970f1..dd2da19eb07 100644 --- a/runtime/onert/api/src/nnfw_api.cc +++ b/runtime/onert/api/src/nnfw_api.cc @@ -388,6 +388,12 @@ NNFW_STATUS nnfw_pop_pipeline_output(nnfw_session *session, void *outputs) // Training +NNFW_STATUS nnfw_train_get_traininfo(nnfw_session *session, nnfw_train_info *info) +{ + NNFW_RETURN_ERROR_IF_NULL(session); + return session->train_get_traininfo(info); +} + NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_info *info) { NNFW_RETURN_ERROR_IF_NULL(session); diff --git a/runtime/onert/api/src/nnfw_api_internal.cc b/runtime/onert/api/src/nnfw_api_internal.cc index b28b65f6777..08719ea92cb 100644 --- a/runtime/onert/api/src/nnfw_api_internal.cc +++ b/runtime/onert/api/src/nnfw_api_internal.cc @@ -1173,6 +1173,91 @@ NNFW_STATUS nnfw_session::set_backends_per_operation(const char *backend_setting return NNFW_STATUS_NO_ERROR; } +NNFW_STATUS nnfw_session::train_get_traininfo(nnfw_train_info *info) +{ + if (isStateInitialized()) + { + // There is no _train_info in INITIALIZED, since _train_info is set when a model loaded + std::cerr << "Error during nnfw_session::train_get_traininfo : invalid state"; + return NNFW_STATUS_INVALID_STATE; + } + + if (info == nullptr) + { + std::cerr << "Error during nnfw_session::train_get_traininfo : info is nullptr" << std::endl; + return NNFW_STATUS_UNEXPECTED_NULL; + } + + // after model loaded, it ensures that _train_info is not nullptr + assert(_train_info != nullptr); + + auto convertLossCode = [](const onert::ir::train::LossCode &code) -> NNFW_TRAIN_LOSS { + switch (code) + { + case onert::ir::train::LossCode::Invalid: + return NNFW_TRAIN_LOSS_UNDEFINED; + case onert::ir::train::LossCode::MeanSquaredError: + return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR; + case onert::ir::train::LossCode::CategoricalCrossentropy: + return NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY; + default: + throw std::runtime_error{"fail to convert ir::train::LossCode"}; + } + }; + + auto convertLossReduction = + [](const onert::ir::train::LossReductionType &type) -> NNFW_TRAIN_LOSS_REDUCTION { + switch (type) + { + case onert::ir::train::LossReductionType::Invalid: + return NNFW_TRAIN_LOSS_REDUCTION_UNDEFINED; + case onert::ir::train::LossReductionType::Auto: + return NNFW_TRAIN_LOSS_REDUCTION_AUTO; + case onert::ir::train::LossReductionType::SumOverBatchSize: + return NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE; + case onert::ir::train::LossReductionType::Sum: + return NNFW_TRAIN_LOSS_REDUCTION_SUM; + default: + throw std::runtime_error{"fail to convert from ir::train::LossReductionType"}; + break; + } + }; + + auto convertOptimizerCode = + [](const onert::ir::train::OptimizerCode &code) -> NNFW_TRAIN_OPTIMIZER { + switch (code) + { + case onert::ir::train::OptimizerCode::Invalid: + return NNFW_TRAIN_OPTIMIZER_UNDEFINED; + case onert::ir::train::OptimizerCode::SGD: + return NNFW_TRAIN_OPTIMIZER_SGD; + case onert::ir::train::OptimizerCode::Adam: + return NNFW_TRAIN_OPTIMIZER_ADAM; + default: + throw std::runtime_error{"fail to convert from ir::train::OptimizerCode"}; + } + }; + + const auto loss = _train_info->lossInfo(); + const auto optim = _train_info->optimizerInfo(); + + try + { + info->learning_rate = optim.learning_rate; + info->batch_size = _train_info->batchSize(); + info->loss_info.loss = convertLossCode(loss.loss_code); + info->loss_info.reduction_type = convertLossReduction(loss.reduction_type); + info->opt = convertOptimizerCode(optim.optim_code); + } + catch (const std::exception &e) + { + std::cerr << "Error during nnfw_session::train_get_traininfo" << e.what() << std::endl; + return NNFW_STATUS_ERROR; + } + + return NNFW_STATUS_NO_ERROR; +} + NNFW_STATUS nnfw_session::train_set_traininfo(const nnfw_train_info *info) { if (not isStateModelLoaded()) diff --git a/runtime/onert/api/src/nnfw_api_internal.h b/runtime/onert/api/src/nnfw_api_internal.h index 5f803921fc6..3ae5d7c750e 100644 --- a/runtime/onert/api/src/nnfw_api_internal.h +++ b/runtime/onert/api/src/nnfw_api_internal.h @@ -170,6 +170,7 @@ struct nnfw_session */ NNFW_STATUS set_backends_per_operation(const char *backend_settings); + NNFW_STATUS train_get_traininfo(nnfw_train_info *info); NNFW_STATUS train_set_traininfo(const nnfw_train_info *info); NNFW_STATUS train_prepare(); NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);