Skip to content

Commit df6b9b8

Browse files
committed
update
1 parent b52d382 commit df6b9b8

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

mllib-dal/src/main/native/DecisionForestClassifierImpl.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,25 @@ static jobject doRFClassifierOneAPICompute(
284284
t1 = std::chrono::high_resolution_clock::now();
285285
const auto result_train =
286286
preview::train(comm, df_desc, hFeaturetable, hLabeltable);
287-
const auto result_infer =
288-
preview::infer(comm, df_desc, result_train.get_model(), hFeaturetable);
289-
jobject trees = nullptr;
290287
if (isRoot) {
291288
logger::println(logger::INFO, "Variable importance results:");
292289
printHomegenTable(result_train.get_var_importance());
293290
logger::println(logger::INFO, "OOB error:");
294291
printHomegenTable(result_train.get_oob_err());
292+
}
293+
const auto result_infer =
294+
preview::infer(comm, df_desc, result_train.get_model(), hFeaturetable);
295+
296+
t2 = std::chrono::high_resolution_clock::now();
297+
duration =
298+
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 -
299+
t1)
300+
.count();
301+
logger::println(logger::INFO,
302+
"DF Classifier (native): training step took %f secs.",
303+
duration / 1000);
304+
jobject trees = nullptr;
305+
if (isRoot) {
295306
logger::println(logger::INFO, "Prediction results:");
296307
printHomegenTable(result_infer.get_responses());
297308
logger::println(logger::INFO, "Probabilities results:\n");

0 commit comments

Comments
 (0)