Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Nov 1, 2024
1 parent b52d382 commit df6b9b8
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions mllib-dal/src/main/native/DecisionForestClassifierImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,25 @@ static jobject doRFClassifierOneAPICompute(
t1 = std::chrono::high_resolution_clock::now();
const auto result_train =
preview::train(comm, df_desc, hFeaturetable, hLabeltable);
const auto result_infer =
preview::infer(comm, df_desc, result_train.get_model(), hFeaturetable);
jobject trees = nullptr;
if (isRoot) {
logger::println(logger::INFO, "Variable importance results:");
printHomegenTable(result_train.get_var_importance());
logger::println(logger::INFO, "OOB error:");
printHomegenTable(result_train.get_oob_err());
}
const auto result_infer =
preview::infer(comm, df_desc, result_train.get_model(), hFeaturetable);

t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 -
t1)
.count();
logger::println(logger::INFO,
"DF Classifier (native): training step took %f secs.",
duration / 1000);
jobject trees = nullptr;
if (isRoot) {
logger::println(logger::INFO, "Prediction results:");
printHomegenTable(result_infer.get_responses());
logger::println(logger::INFO, "Probabilities results:\n");
Expand Down

0 comments on commit df6b9b8

Please sign in to comment.