Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Aug 5, 2024
1 parent ffac9dd commit ab55d3d
Showing 1 changed file with 17 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ class CorrelationDALImpl(
val kvsIPPort = getOneCCLIPPort(coalescedTables)
val training_breakdown_name = "Correlation_training_breakdown_" + executorNum;

// coalescedTables.mapPartitionsWithIndex { (rank, iter) =>
// logInfo(s"set ZE_AFFINITY_MASK")
// val gpuIndices = if (useDevice == "GPU") {
// val resources = TaskContext.get().resources()
// resources("gpu").addresses.map(_.toInt)
// } else {
// null
// }
// logInfo(s"set ZE_AFFINITY_MASK rank is $rank.")
// logInfo(s"gpuIndices is ${gpuIndices.mkString(", ")}.")
// OneCCL.setExecutorEnv("ZE_AFFINITY_MASK", gpuIndices(0).toString())
// Iterator.empty
// }.count()
coalescedTables.mapPartitionsWithIndex { (rank, iter) =>
logInfo(s"set ZE_AFFINITY_MASK")
val gpuIndices = if (useDevice == "GPU") {
val resources = TaskContext.get().resources()
resources("gpu").addresses.map(_.toInt)
} else {
null
}
logInfo(s"set ZE_AFFINITY_MASK rank is $rank.")
logInfo(s"gpuIndices is ${gpuIndices.mkString(", ")}.")
OneCCL.setExecutorEnv("ZE_AFFINITY_MASK", gpuIndices(0).toString())
Iterator.empty
}.count()

if (useDevice == "CPU") {
coalescedTables.mapPartitionsWithIndex { (rank, table) =>
Expand Down Expand Up @@ -145,6 +145,10 @@ class CorrelationDALImpl(
}


def CorrelationSampleTrainDAL(data: RDD[Vector]) = {

}

@native private[mllib] def cCorrelationTrainDAL(rank: Int,
data: Long,
numRows: Long,
Expand Down

0 comments on commit ab55d3d

Please sign in to comment.