From df6b9b8b22631e21e1a8967f39b62ef8be55e35d Mon Sep 17 00:00:00 2001 From: minmingzhu Date: Fri, 1 Nov 2024 17:56:23 +0800 Subject: [PATCH] update --- .../native/DecisionForestClassifierImpl.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/mllib-dal/src/main/native/DecisionForestClassifierImpl.cpp b/mllib-dal/src/main/native/DecisionForestClassifierImpl.cpp index 20adce205..4ebd72f4b 100644 --- a/mllib-dal/src/main/native/DecisionForestClassifierImpl.cpp +++ b/mllib-dal/src/main/native/DecisionForestClassifierImpl.cpp @@ -284,14 +284,25 @@ static jobject doRFClassifierOneAPICompute( t1 = std::chrono::high_resolution_clock::now(); const auto result_train = preview::train(comm, df_desc, hFeaturetable, hLabeltable); - const auto result_infer = - preview::infer(comm, df_desc, result_train.get_model(), hFeaturetable); - jobject trees = nullptr; if (isRoot) { logger::println(logger::INFO, "Variable importance results:"); printHomegenTable(result_train.get_var_importance()); logger::println(logger::INFO, "OOB error:"); printHomegenTable(result_train.get_oob_err()); + } + const auto result_infer = + preview::infer(comm, df_desc, result_train.get_model(), hFeaturetable); + + t2 = std::chrono::high_resolution_clock::now(); + duration = + (float)std::chrono::duration_cast(t2 - + t1) + .count(); + logger::println(logger::INFO, + "DF Classifier (native): training step took %f secs.", + duration / 1000); + jobject trees = nullptr; + if (isRoot) { logger::println(logger::INFO, "Prediction results:"); printHomegenTable(result_infer.get_responses()); logger::println(logger::INFO, "Probabilities results:\n");