Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Nov 1, 2024
1 parent 06d7caa commit 1d8521d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 20 deletions.
8 changes: 2 additions & 6 deletions mllib-dal/src/main/native/DecisionForestClassifierImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,6 @@ static jobject doRFClassifierOneAPICompute(
preview::train(comm, df_desc, hFeaturetable, hLabeltable);
const auto result_infer =
preview::infer(comm, df_desc, result_train.get_model(), hFeaturetable);
t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 -
t1)
.count();
jobject trees = nullptr;
if (isRoot) {
logger::println(logger::INFO, "Variable importance results:");
Expand All @@ -300,7 +295,8 @@ static jobject doRFClassifierOneAPICompute(
logger::println(logger::INFO, "Prediction results:");
printHomegenTable(result_infer.get_responses());
logger::println(logger::INFO, "Probabilities results:\n");
printHomegenTable(result_infer.get_probabilities());
printHomegenTable(result_infer.get_probabilities());ll

t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 -
Expand Down
6 changes: 0 additions & 6 deletions mllib-dal/src/main/native/DecisionForestRegressorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,6 @@ static jobject doRFRegressorOneAPICompute(
preview::train(comm, df_desc, hFeaturetable, hLabeltable);
const auto result_infer =
preview::infer(comm, df_desc, result_train.get_model(), hFeaturetable);
t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 -
t1)
.count();
jobject trees = nullptr;
if (isRoot) {
logger::println(logger::INFO, "Variable importance results:");
Expand All @@ -289,7 +284,6 @@ static jobject doRFRegressorOneAPICompute(
printHomegenTable(result_train.get_oob_err());
logger::println(logger::INFO, "Prediction results:");
printHomegenTable(result_infer.get_responses());

t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 -
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,12 @@ class RandomForestClassifierDALImpl(val uid: String,
val kvsIPPort = getOneCCLIPPort(labeledPointsTables)
val training_breakdown_name = "RFClassifier_training_breakdown_" + executorNum;

labeledPointsTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, training_breakdown_name, storePath)
Iterator.empty
}.count()
if (useDevice == "CPU") {
labeledPointsTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, training_breakdown_name, storePath)
Iterator.empty
}.count()
}
rfcTimer.record("OneCCL Init")

val results = labeledPointsTables.mapPartitionsWithIndex { (rank, tables) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ class RandomForestRegressorDALImpl(val uid: String,
val kvsIPPort = getOneCCLIPPort(labeledPointsTables)
val training_breakdown_name = "RFRegressor_training_breakdown_" + executorNum;

labeledPointsTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, training_breakdown_name, storePath)
Iterator.empty
}.count()
if (useDevice == "CPU") {
labeledPointsTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, training_breakdown_name, storePath)
Iterator.empty
}.count()
}
rfrTimer.record("OneCCL Init")

val results = labeledPointsTables.mapPartitionsWithIndex { (rank, tables) =>
Expand Down

0 comments on commit 1d8521d

Please sign in to comment.