Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Nov 1, 2024
1 parent 0971db6 commit 7ccf98d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mllib-dal/src/main/native/DecisionForestRegressorImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ Java_com_intel_oap_mllib_regression_RandomForestRegressorDALImpl_cRFRegressorTra
t1 = std::chrono::high_resolution_clock::now();
auto comm =
preview::spmd::make_communicator<preview::spmd::backend::ccl>(
queue, executorNum, rank, kvs);
queue, executorNum, rank, kvs);:queue
t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,21 @@ class RandomForestClassifierDALImpl(val uid: String,
"Please run on GPU device.")
}
rfcTimer.record("Data Convertion")

labeledPointsTables.mapPartitionsWithIndex { (rank, iter) =>
logInfo(s"set ZE_AFFINITY_MASK")
val gpuIndices = if (useDevice == "GPU") {
val resources = TaskContext.get().resources()
resources("gpu").addresses.map(_.toInt)
} else {
null
}
logInfo(s"set ZE_AFFINITY_MASK rank is $rank.")
logInfo(s"gpuIndices is ${gpuIndices.mkString(", ")}.")
OneCCL.setExecutorEnv("ZE_AFFINITY_MASK", gpuIndices(0).toString())
Iterator.empty
}.count()

val kvsIPPort = getOneCCLIPPort(labeledPointsTables)
val training_breakdown_name = "RFClassifier_training_breakdown_" + executorNum;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ class RandomForestRegressorDALImpl(val uid: String,
}
rfrTimer.record("Data Convertion")

labeledPointsTables.mapPartitionsWithIndex { (rank, iter) =>
logInfo(s"set ZE_AFFINITY_MASK")
val gpuIndices = if (useDevice == "GPU") {
val resources = TaskContext.get().resources()
resources("gpu").addresses.map(_.toInt)
} else {
null
}
logInfo(s"set ZE_AFFINITY_MASK rank is $rank.")
logInfo(s"gpuIndices is ${gpuIndices.mkString(", ")}.")
OneCCL.setExecutorEnv("ZE_AFFINITY_MASK", gpuIndices(0).toString())
Iterator.empty
}.count()

val kvsIPPort = getOneCCLIPPort(labeledPointsTables)
val training_breakdown_name = "RFRegressor_training_breakdown_" + executorNum;

Expand Down

0 comments on commit 7ccf98d

Please sign in to comment.