From bd44e63df570df1aa4cd06d78bb0c24538139c36 Mon Sep 17 00:00:00 2001 From: minmingzhu Date: Wed, 13 Nov 2024 11:19:29 +0800 Subject: [PATCH] fix cpu error --- .../src/main/native/CCLInitSingleton.hpp | 9 ++++- mllib-dal/src/main/native/OneCCL.cpp | 40 +++++-------------- .../scala/com/intel/oap/mllib/CommonJob.scala | 6 ++- 3 files changed, 21 insertions(+), 34 deletions(-) diff --git a/mllib-dal/src/main/native/CCLInitSingleton.hpp b/mllib-dal/src/main/native/CCLInitSingleton.hpp index 2805f8e3f..553060b08 100644 --- a/mllib-dal/src/main/native/CCLInitSingleton.hpp +++ b/mllib-dal/src/main/native/CCLInitSingleton.hpp @@ -41,13 +41,20 @@ class CCLInitSingleton { auto t1 = std::chrono::high_resolution_clock::now(); ccl::init(); + auto t2 = std::chrono::high_resolution_clock::now(); + auto duration = + (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1) + .count(); + logger::println(logger::INFO, "OneCCL (native): init took %f secs", + duration / 1000); + t1 = std::chrono::high_resolution_clock::now(); auto kvs_attr = ccl::create_kvs_attr(); kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port); kvs = ccl::create_main_kvs(kvs_attr); - auto t2 = std::chrono::high_resolution_clock::now(); + t2 = std::chrono::high_resolution_clock::now(); auto duration = (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count(); diff --git a/mllib-dal/src/main/native/OneCCL.cpp b/mllib-dal/src/main/native/OneCCL.cpp index 587556d63..e1500725a 100644 --- a/mllib-dal/src/main/native/OneCCL.cpp +++ b/mllib-dal/src/main/native/OneCCL.cpp @@ -61,52 +61,31 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init( jobject param) { logger::println(logger::INFO, "OneCCL (native): init"); + auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port); +#ifdef CPU_ONLY_PROFILE auto t1 = std::chrono::high_resolution_clock::now(); - - ccl::init(); + g_comms.push_back( + ccl::create_communicator(size, rank, singletonCCLInit.kvs)); auto t2 = std::chrono::high_resolution_clock::now(); auto duration = 
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1) .count(); - logger::println(logger::INFO, "OneCCL (native): init took %f secs", + logger::println(logger::INFO, "OneCCL (native): create communicator took %f secs", duration / 1000); - const char *str = env->GetStringUTFChars(ip_port, 0); - ccl::string ccl_ip_port(str); - -#ifdef CPU_ONLY_PROFILE - auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port); - - g_kvs.push_back(singletonCCLInit.kvs); - g_comms.push_back( - ccl::create_communicator(size, rank, singletonCCLInit.kvs)); - rank_id = getComm().rank(); comm_size = getComm().size(); #endif #ifdef CPU_GPU_PROFILE - t1 = std::chrono::high_resolution_clock::now(); - auto kvs_attr = ccl::create_kvs_attr(); - - kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port); - - ccl::shared_ptr_class<ccl::kvs> kvs = ccl::create_main_kvs(kvs_attr); - - t2 = std::chrono::high_resolution_clock::now(); - duration = - (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1) - .count(); - logger::println(logger::INFO, "OneCCL (native): create kvs took %f secs", - duration / 1000); auto gpus = get_gpus(); sycl::queue queue{gpus[0]}; - t1 = std::chrono::high_resolution_clock::now(); + auto t1 = std::chrono::high_resolution_clock::now(); auto comm = oneapi::dal::preview::spmd::make_communicator< - oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank, kvs); - t2 = std::chrono::high_resolution_clock::now(); - duration = + oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank, singletonCCLInit.kvs); + auto t2 = std::chrono::high_resolution_clock::now(); + auto duration = (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1) .count(); logger::println(logger::INFO, @@ -131,7 +110,6 @@ JNIEXPORT void JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1cleanup(JNIEnv *env, jobject obj) { logger::printerrln(logger::INFO, "OneCCL (native): cleanup"); #ifdef CPU_ONLY_PROFILE - g_kvs.pop_back(); g_comms.pop_back(); #endif #ifdef CPU_GPU_PROFILE diff --git a/mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala 
b/mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala index bd2a7be18..3a791ebc0 100644 --- a/mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala +++ b/mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala @@ -30,9 +30,11 @@ object CommonJob { val resources = TaskContext.get().resources() resources("gpu").addresses.map(_.toInt) } else { - null + Array.empty[Int] + } + if (gpuIndices.nonEmpty) { + OneCCL.setAffinityMask(gpuIndices(0).toString()) } - OneCCL.setAffinityMask(gpuIndices(0).toString()) Iterator.empty }.count() }