
Commit

update
minmingzhu committed Jul 18, 2024
1 parent 14ff68a commit b2c7e4d
Showing 6 changed files with 22 additions and 23 deletions.
16 changes: 7 additions & 9 deletions mllib-dal/src/main/native/CorrelationImpl.cpp
@@ -228,19 +228,19 @@ static void doCorrelationOneAPICompute(
 
 JNIEXPORT jlong JNICALL
 Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
-    JNIEnv *env, jobject obj, jlong pNumTabData, jlong numRows, jlong numClos,
+    JNIEnv *env, jobject obj, jint rank, jlong pNumTabData, jlong numRows, jlong numClos,
     jint executorNum, jint executorCores, jint computeDeviceOrdinal,
     jintArray gpuIdxArray, jstring breakdown_name, jobject resultObj) {
     logger::println(logger::INFO,
                     "oneDAL (native): use DPC++ kernels; device %s",
                     ComputeDeviceString[computeDeviceOrdinal].c_str());
 
-    ccl::communicator &cclComm = getComm();
-    int rankId = cclComm.rank();
     ComputeDevice device = getComputeDeviceByOrdinal(computeDeviceOrdinal);
     switch (device) {
     case ComputeDevice::host:
     case ComputeDevice::cpu: {
+        ccl::communicator &cclComm = getComm();
+        int rankId = cclComm.rank();
         NumericTablePtr pData = *((NumericTablePtr *)pNumTabData);
         // Set number of threads for oneDAL to use for each rank
         services::Environment::getInstance()->setNumberOfThreads(executorCores);
@@ -260,26 +260,24 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
         logger::println(
             logger::INFO,
             "oneDAL (native): use GPU kernels with %d GPU(s) rankid %d", nGpu,
-            rankId);
+            rank);
 
         jint *gpuIndices = env->GetIntArrayElements(gpuIdxArray, 0);
         const char *cstr = env->GetStringUTFChars(breakdown_name, nullptr);
         std::string c_breakdown_name(cstr);
-        int size = cclComm.size();
 
-        auto queue =
-            getAssignedGPU(device, cclComm, size, rankId, gpuIndices, nGpu);
+        auto queue = getGPU(device, gpuIndices);
 
         ccl::shared_ptr_class<ccl::kvs> &kvs = getKvs();
         auto t1 = std::chrono::high_resolution_clock::now();
         auto comm =
             preview::spmd::make_communicator<preview::spmd::backend::ccl>(
-                queue, size, rankId, kvs);
+                queue, executorNum, rank, kvs);
         auto t2 = std::chrono::high_resolution_clock::now();
         auto duration =
             (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
                 .count();
-        logger::Logger::getInstance(c_breakdown_name).printLogToFile("rankID was %d, create communicator took %f secs.", rankId, duration / 1000);
+        logger::Logger::getInstance(c_breakdown_name).printLogToFile("rankID was %d, create communicator took %f secs.", rank, duration / 1000);
         doCorrelationOneAPICompute(env, pNumTabData, numRows, numClos, comm,
                                    resultObj, queue, c_breakdown_name);
 
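Taken together, the two hunks above move the SPMD identity of the GPU branch from the native CCL layer to the JVM: rank and world size now arrive as JNI arguments (rank, executorNum) instead of being read from cclComm.rank() and cclComm.size(), and device selection drops the communicator-based getAssignedGPU in favor of getGPU. A minimal sketch of the resulting GPU branch, assuming the repository helpers getGPU and getKvs as used in the diff, with string handling, timing, and error handling omitted:

    // Sketch only, not the literal file contents. The ComputeDevice::gpu
    // case label is an assumption; the hunk above begins below the label.
    case ComputeDevice::gpu: {
        jint *gpuIndices = env->GetIntArrayElements(gpuIdxArray, 0);

        // GPU selection no longer consults the CCL communicator.
        auto queue = getGPU(device, gpuIndices);

        // World size and rank are the JVM-supplied executorNum and rank,
        // not cclComm.size() / cclComm.rank().
        ccl::shared_ptr_class<ccl::kvs> &kvs = getKvs();
        auto comm =
            preview::spmd::make_communicator<preview::spmd::backend::ccl>(
                queue, executorNum, rank, kvs);

        doCorrelationOneAPICompute(env, pNumTabData, numRows, numClos, comm,
                                   resultObj, queue, c_breakdown_name);
        env->ReleaseIntArrayElements(gpuIdxArray, gpuIndices, 0);
        break;
    }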
17 changes: 7 additions & 10 deletions mllib-dal/src/main/native/PCAImpl.cpp
@@ -263,19 +263,19 @@ static void doPCAOneAPICompute(
 
 JNIEXPORT jlong JNICALL
 Java_com_intel_oap_mllib_feature_PCADALImpl_cPCATrainDAL(
-    JNIEnv *env, jobject obj, jlong pNumTabData, jlong numRows, jlong numClos,
+    JNIEnv *env, jobject obj, jint rank, jlong pNumTabData, jlong numRows, jlong numClos,
     jint executorNum, jint executorCores, jint computeDeviceOrdinal,
     jintArray gpuIdxArray, jstring breakdown_name, jobject resultObj) {
     logger::println(logger::INFO,
                     "oneDAL (native): use DPC++ kernels; device %s",
                     ComputeDeviceString[computeDeviceOrdinal].c_str());
 
-    ccl::communicator &cclComm = getComm();
-    size_t rankId = cclComm.rank();
     ComputeDevice device = getComputeDeviceByOrdinal(computeDeviceOrdinal);
     switch (device) {
     case ComputeDevice::host:
     case ComputeDevice::cpu: {
+        ccl::communicator &cclComm = getComm();
+        size_t rankId = cclComm.rank();
         NumericTablePtr pData = *((NumericTablePtr *)pNumTabData);
         // Set number of threads for oneDAL to use for each rank
         services::Environment::getInstance()->setNumberOfThreads(executorCores);
@@ -295,27 +295,24 @@ Java_com_intel_oap_mllib_feature_PCADALImpl_cPCATrainDAL(
         logger::println(
             logger::INFO,
             "oneDAL (native): use GPU kernels with %d GPU(s) rankid %d", nGpu,
-            rankId);
+            rank);
 
         jint *gpuIndices = env->GetIntArrayElements(gpuIdxArray, 0);
         const char *cstr = env->GetStringUTFChars(breakdown_name, nullptr);
         std::string c_breakdown_name(cstr);
 
-        int size = cclComm.size();
-
-        auto queue =
-            getAssignedGPU(device, cclComm, size, rankId, gpuIndices, nGpu);
+        auto queue = getGPU(device, gpuIndices);
 
         ccl::shared_ptr_class<ccl::kvs> &kvs = getKvs();
         auto t1 = std::chrono::high_resolution_clock::now();
         auto comm =
             preview::spmd::make_communicator<preview::spmd::backend::ccl>(
                 queue, executorNum, rank, kvs);
-                queue, size, rankId, kvs);
         auto t2 = std::chrono::high_resolution_clock::now();
         auto duration =
             (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
                 .count();
-        logger::Logger::getInstance(c_breakdown_name).printLogToFile("rankID was %d, create communicator took %f secs.", rankId, duration / 1000);
+        logger::Logger::getInstance(c_breakdown_name).printLogToFile("rankID was %d, create communicator took %f secs.", rank, duration / 1000);
         doPCAOneAPICompute(env, pNumTabData, numRows, numClos, comm, resultObj,
                            queue, c_breakdown_name);
         env->ReleaseIntArrayElements(gpuIdxArray, gpuIndices, 0);
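One detail neither hunk shows: both functions copy the Java string with GetStringUTFChars, and only the matching ReleaseIntArrayElements is visible (in PCAImpl.cpp); any ReleaseStringUTFChars would sit in the elided region. As a hedged reminder, the standard JNI acquire/release pairing is safe to apply as soon as the characters are copied into a std::string:

    // Illustration of the usual JNI pairing; whether the real files already
    // release cstr below the visible hunks is not shown in this diff.
    const char *cstr = env->GetStringUTFChars(breakdown_name, nullptr);
    std::string c_breakdown_name(cstr);               // deep copy
    env->ReleaseStringUTFChars(breakdown_name, cstr); // safe after the copy

    jint *gpuIndices = env->GetIntArrayElements(gpuIdxArray, 0);
    // ... use gpuIndices ...
    env->ReleaseIntArrayElements(gpuIdxArray, gpuIndices, 0); // as PCAImpl.cpp does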

Some generated files are not rendered by default.

@@ -86,6 +86,7 @@ class PCADALImpl(val k: Int,
         null
       }
       cPCATrainDAL(
+        rank,
         tableArr,
         rows,
         columns,
@@ -222,7 +223,8 @@ class PCADALImpl(val k: Int,
 
 
   // Single entry to call Correlation PCA DAL backend with parameter K
-  @native private[mllib] def cPCATrainDAL(data: Long,
+  @native private[mllib] def cPCATrainDAL(rank: Int,
+                                          data: Long,
                                           numRows: Long,
                                           numCols: Long,
                                           executorNum: Int,
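The @native declaration above must line up parameter for parameter with the JNI function in PCAImpl.cpp; adding rank on only one side would misalign every argument that follows it at runtime. The correspondence for the parameters visible in this diff (the Scala hunk is cut off after executorNum, so later parameters are left unannotated):

    // JNI side, from the PCAImpl.cpp hunk above, annotated with the
    // matching Scala parameters.
    JNIEXPORT jlong JNICALL
    Java_com_intel_oap_mllib_feature_PCADALImpl_cPCATrainDAL(
        JNIEnv *env, jobject obj, // implicit in the Scala declaration
        jint rank,                // rank: Int
        jlong pNumTabData,        // data: Long
        jlong numRows,            // numRows: Long
        jlong numClos,            // numCols: Long
        jint executorNum,         // executorNum: Int
        jint executorCores, jint computeDeviceOrdinal,
        jintArray gpuIdxArray, jstring breakdown_name, jobject resultObj);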
@@ -75,6 +75,7 @@ class CorrelationDALImpl(
         null
       }
       cCorrelationTrainDAL(
+        rank,
         tableArr,
         rows,
         columns,
@@ -125,7 +126,8 @@ class CorrelationDALImpl(
   }
 
 
-  @native private[mllib] def cCorrelationTrainDAL(data: Long,
+  @native private[mllib] def cCorrelationTrainDAL(rank: Int,
+                                                  data: Long,
                                                   numRows: Long,
                                                   numCols: Long,
                                                   executorNum: Int,
