Skip to content

Commit

Permalink
[onert/test] Add util to print nnfw_train_info (Samsung#12296)
Browse files Browse the repository at this point in the history
This PR aims to print nnfw_train_info in onert_train.

ONE-DCO-1.0-Signed-off-by: SeungHui Youn <[email protected]>
  • Loading branch information
zetwhite authored Dec 19, 2023
1 parent ddab4b3 commit b2c664f
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 1 deletion.
72 changes: 72 additions & 0 deletions tests/tools/onert_train/src/nnfw_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

#include <cassert>
#include <string>
#include <iostream>

#include "nnfw.h"
#include "nnfw_experimental.h"

namespace onert_train
{
Expand Down Expand Up @@ -46,4 +49,73 @@ uint64_t bufsize_for(const nnfw_tensorinfo *ti)
return elmsize[ti->dtype] * num_elems(ti);
}

// Print a human-readable name for the training optimizer enum value.
// Unknown enum values are reported as "unsupported optimizer".
std::ostream &operator<<(std::ostream &os, const NNFW_TRAIN_OPTIMIZER &opt)
{
  if (opt == NNFW_TRAIN_OPTIMIZER_ADAM)
  {
    os << "adam";
  }
  else if (opt == NNFW_TRAIN_OPTIMIZER_SGD)
  {
    os << "sgd";
  }
  else
  {
    os << "unsupported optimizer";
  }
  return os;
}

// Print a human-readable name for the training loss-function enum value.
// Unknown enum values are reported as "unsupported loss".
std::ostream &operator<<(std::ostream &os, const NNFW_TRAIN_LOSS &loss)
{
  if (loss == NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR)
  {
    os << "mean squared error";
  }
  else if (loss == NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY)
  {
    os << "categorical crossentropy";
  }
  else
  {
    os << "unsupported loss";
  }
  return os;
}

// Print a human-readable name for the loss-reduction enum value.
// INVALID means the caller left it unset, so the runtime default applies;
// unknown enum values are reported as "unsupported reduction type".
std::ostream &operator<<(std::ostream &os, const NNFW_TRAIN_LOSS_REDUCTION &loss_reduction)
{
  if (loss_reduction == NNFW_TRAIN_LOSS_REDUCTION_INVALID)
  {
    os << "use default setting";
  }
  else if (loss_reduction == NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE)
  {
    os << "sum over batch size";
  }
  else if (loss_reduction == NNFW_TRAIN_LOSS_REDUCTION_SUM)
  {
    os << "sum";
  }
  else
  {
    os << "unsupported reduction type";
  }
  return os;
}

// Print the loss configuration as "{loss = <name>, reduction = <name>}",
// delegating each field to its own operator<< overload.
std::ostream &operator<<(std::ostream &os, const nnfw_loss_info &loss_info)
{
  os << "{loss = " << loss_info.loss;
  os << ", reduction = " << loss_info.reduction_type;
  os << "}";
  return os;
}

// Print all training hyperparameters, one "- key = value" line per field.
// Nested struct/enum fields reuse the overloads defined above.
std::ostream &operator<<(std::ostream &os, const nnfw_train_info &info)
{
  return os << "- learning_rate = " << info.learning_rate << "\n"
            << "- batch_size = " << info.batch_size << "\n"
            << "- loss_info = " << info.loss_info << "\n"
            << "- optimizer = " << info.opt << "\n";
}

} // namespace onert_train
4 changes: 3 additions & 1 deletion tests/tools/onert_train/src/nnfw_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define __ONERT_TRAIN_NNFW_UTIL_H__

#include "nnfw.h"
#include "nnfw_experimental.h"

#define NNPR_ENSURE_STATUS(a) \
do \
Expand All @@ -32,6 +33,7 @@ namespace onert_train
{
uint64_t num_elems(const nnfw_tensorinfo *ti);
uint64_t bufsize_for(const nnfw_tensorinfo *ti);
} // end of namespace onert_train

std::ostream &operator<<(std::ostream &os, const nnfw_train_info &info);
} // end of namespace onert_train
#endif // __ONERT_TRAIN_NNFW_UTIL_H__
3 changes: 3 additions & 0 deletions tests/tools/onert_train/src/onert_train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ int main(const int argc, char **argv)
tri.loss_info.reduction_type = convertLossReductionType(args.getLossReductionType());
tri.opt = convertOptType(args.getOptimizerType());

std::cout << "== training parameter ==" << std::endl;
std::cout << tri;
std::cout << "========================" << std::endl;
// prepare execution

// TODO When nnfw_{prepare|run} are failed, can't catch the time
Expand Down

0 comments on commit b2c664f

Please sign in to comment.