Skip to content

Commit

Permalink
[tools/onert_train] Use optimizer type argument (#11171)
Browse files Browse the repository at this point in the history
This commit uses optimizer type argument and passes it to nnfw_train_info
structure.

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
  • Loading branch information
jyoungyun authored Jul 27, 2023
1 parent f66855e commit e476e73
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/tools/onert_train/src/onert_train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,25 @@ int main(const int argc, char **argv)
}
};

auto convertOptType = [](int type) {
switch (type)
{
case 0:
return NNFW_TRAIN_OPTIMIZER_SGD;
case 1:
return NNFW_TRAIN_OPTIMIZER_ADAM;
default:
std::cerr << "E: not supported optimizer type" << std::endl;
exit(-1);
}
};

// prepare training info
nnfw_train_info tri;
tri.batch_size = args.getBatchSize();
tri.learning_rate = args.getLearningRate();
tri.loss = convertLossType(args.getLossType());
tri.opt = convertOptType(args.getOptimizerType());

// prepare execution

Expand Down

0 comments on commit e476e73

Please sign in to comment.