From d65e29b69c2e48902addbf7011b12dc3a1863722 Mon Sep 17 00:00:00 2001 From: minmingzhu Date: Mon, 29 Jul 2024 17:50:01 +0800 Subject: [PATCH] update --- mllib-dal/src/main/native/CorrelationImpl.cpp | 45 ++++++++++++++++--- ..._intel_oap_mllib_stat_CorrelationDALImpl.h | 2 +- .../oap/mllib/stat/CorrelationDALImpl.scala | 12 +++-- 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/mllib-dal/src/main/native/CorrelationImpl.cpp b/mllib-dal/src/main/native/CorrelationImpl.cpp index 70c4ffbba..e9de3f0c7 100644 --- a/mllib-dal/src/main/native/CorrelationImpl.cpp +++ b/mllib-dal/src/main/native/CorrelationImpl.cpp @@ -287,7 +287,7 @@ JNIEXPORT jlong JNICALL Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL( JNIEnv *env, jobject obj, jint rank, jlong pNumTabData, jlong numRows, jlong numClos, jint executorNum, jint executorCores, jint computeDeviceOrdinal, - jintArray gpuIdxArray, jstring breakdown_name, jobject resultObj) { + jintArray gpuIdxArray, jstring ip_port, jstring breakdown_name, jobject resultObj) { logger::println(logger::INFO, "oneDAL (native): use DPC++ kernels; device %s", ComputeDeviceString[computeDeviceOrdinal].c_str()); @@ -322,18 +322,52 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL( jint *gpuIndices = env->GetIntArrayElements(gpuIdxArray, 0); const char* cstr = env->GetStringUTFChars(breakdown_name, nullptr); std::string c_breakdown_name(cstr); + const char *str = env->GetStringUTFChars(ip_port, 0); + ccl::string ccl_ip_port(str); + + auto t1 = std::chrono::high_resolution_clock::now(); + logger::println(logger::INFO, "CCLInitSingleton name %s", + name); + ccl::init(); + + auto t2 = std::chrono::high_resolution_clock::now(); + auto duration = + (float)std::chrono::duration_cast(t2 - t1).count(); + + logger::println(logger::INFO, "OneCCL singleton init took %f secs", + duration / 1000); + logger::Logger::getInstance(name).printLogToFile("rankID was %d, OneCCL singleton init took %f secs.", rank, duration / 1000 ); + + + t1 = std::chrono::high_resolution_clock::now(); + logger::println(logger::INFO, "OneCCL (native): create_kvs_attr"); + + auto kvs_attr = ccl::create_kvs_attr(); + + kvs_attr.set(ccl_ip_port); + logger::println(logger::INFO, "OneCCL (native): create_main_kvs"); + + auto kvs = ccl::create_main_kvs(kvs_attr); + logger::println(logger::INFO, "OneCCL (native): g_ccl_kvs.push_back(kvs)"); + + t2 = std::chrono::high_resolution_clock::now(); + duration = + (float)std::chrono::duration_cast(t2 - t1) + .count(); + logger::println(logger::INFO, "OneCCL (native): init took %f secs", + duration / 1000); + logger::Logger::getInstance(name).printLogToFile("rankID was %d, OneCCL create communicator took %f secs.", rank, duration / 1000 ); // auto queue = getGPU(device, gpuIndices); auto device = sycl::device(sycl::gpu_selector_v); sycl::queue queue{device}; - ccl::shared_ptr_class &kvs = getKvs(); - auto t1 = std::chrono::high_resolution_clock::now(); + t1 = std::chrono::high_resolution_clock::now(); auto comm = preview::spmd::make_communicator( queue, executorNum, rank, kvs); - auto t2 = std::chrono::high_resolution_clock::now(); - auto duration = + t2 = std::chrono::high_resolution_clock::now(); + duration = (float)std::chrono::duration_cast(t2 - t1) .count(); logger::Logger::getInstance(c_breakdown_name).printLogToFile("rankID was %d, create communicator took %f secs.", rank, duration / 1000 ); @@ -342,6 +376,7 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL( env->ReleaseIntArrayElements(gpuIdxArray, gpuIndices, 0); env->ReleaseStringUTFChars(breakdown_name, cstr); + env->ReleaseStringUTFChars(ip_port, str); break; } #endif diff --git a/mllib-dal/src/main/native/javah/com_intel_oap_mllib_stat_CorrelationDALImpl.h b/mllib-dal/src/main/native/javah/com_intel_oap_mllib_stat_CorrelationDALImpl.h index 76bee182a..160649ee9 100644 --- a/mllib-dal/src/main/native/javah/com_intel_oap_mllib_stat_CorrelationDALImpl.h +++ b/mllib-dal/src/main/native/javah/com_intel_oap_mllib_stat_CorrelationDALImpl.h @@ -13,7 +13,7 @@ extern "C" { * Signature: (JJJIII[ILjava/lang/String;Lcom/intel/oap/mllib/stat/CorrelationResult;)J */ JNIEXPORT jlong JNICALL Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL - (JNIEnv *, jobject, jint, jlong, jlong, jlong, jint, jint, jint, jintArray, jstring, jobject); + (JNIEnv *, jobject, jint, jlong, jlong, jlong, jint, jint, jint, jintArray, jstring, jstring, jobject); /* * Class: com_intel_oap_mllib_stat_CorrelationDALImpl diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/CorrelationDALImpl.scala b/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/CorrelationDALImpl.scala index 0e2d35776..085a07e41 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/CorrelationDALImpl.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/stat/CorrelationDALImpl.scala @@ -65,10 +65,12 @@ class CorrelationDALImpl( Iterator.empty }.count() - coalescedTables.mapPartitionsWithIndex { (rank, table) => - OneCCL.init(executorNum, rank, kvsIPPort, training_breakdown_name, storePath) - Iterator.empty - }.count() + if (useDevice == "CPU") { + coalescedTables.mapPartitionsWithIndex { (rank, table) => + OneCCL.init(executorNum, rank, kvsIPPort, training_breakdown_name, storePath) + Iterator.empty + }.count() + } corTimer.record("OneCCL Init") val results = coalescedTables.mapPartitionsWithIndex { (rank, iter) => @@ -97,6 +99,7 @@ class CorrelationDALImpl( executorCores, computeDevice.ordinal(), gpuIndices, + kvsIPPort, training_breakdown_name, result ) @@ -148,6 +151,7 @@ class CorrelationDALImpl( executorCores: Int, computeDeviceOrdinal: Int, gpuIndices: Array[Int], + kvsIPPort: String, training_breakdown_name: String, result: CorrelationResult): Long