Skip to content

Commit

Permalink
fix cpu error
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Nov 13, 2024
1 parent 071eb75 commit bd44e63
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 34 deletions.
9 changes: 8 additions & 1 deletion mllib-dal/src/main/native/CCLInitSingleton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,20 @@ class CCLInitSingleton {
auto t1 = std::chrono::high_resolution_clock::now();

ccl::init();
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);

t1 = std::chrono::high_resolution_clock::now();
auto kvs_attr = ccl::create_kvs_attr();
kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);

kvs = ccl::create_main_kvs(kvs_attr);

auto t2 = std::chrono::high_resolution_clock::now();
t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();

Expand Down
40 changes: 9 additions & 31 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,52 +61,31 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
jobject param) {

logger::println(logger::INFO, "OneCCL (native): init");
auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port);

#ifdef CPU_ONLY_PROFILE
auto t1 = std::chrono::high_resolution_clock::now();

ccl::init();
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
logger::println(logger::INFO, "OneCCL (native): create communicator took %f secs",
duration / 1000);
const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);

#ifdef CPU_ONLY_PROFILE
auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port);

g_kvs.push_back(singletonCCLInit.kvs);
g_comms.push_back(
ccl::create_communicator(size, rank, singletonCCLInit.kvs));

rank_id = getComm().rank();
comm_size = getComm().size();

#endif

#ifdef CPU_GPU_PROFILE
t1 = std::chrono::high_resolution_clock::now();
auto kvs_attr = ccl::create_kvs_attr();

kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);

ccl::shared_ptr_class<ccl::kvs> kvs = ccl::create_main_kvs(kvs_attr);

t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): create kvs took %f secs",
duration / 1000);
auto gpus = get_gpus();
sycl::queue queue{gpus[0]};
t1 = std::chrono::high_resolution_clock::now();
auto t1 = std::chrono::high_resolution_clock::now();
auto comm = oneapi::dal::preview::spmd::make_communicator<
oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank, kvs);
t2 = std::chrono::high_resolution_clock::now();
duration =
oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank, singletonCCLInit.kvs);
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO,
Expand All @@ -131,7 +110,6 @@ JNIEXPORT void JNICALL
Java_com_intel_oap_mllib_OneCCL_00024_c_1cleanup(JNIEnv *env, jobject obj) {
logger::printerrln(logger::INFO, "OneCCL (native): cleanup");
#ifdef CPU_ONLY_PROFILE
g_kvs.pop_back();
g_comms.pop_back();
#endif
#ifdef CPU_GPU_PROFILE
Expand Down
6 changes: 4 additions & 2 deletions mllib-dal/src/main/scala/com/intel/oap/mllib/CommonJob.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ object CommonJob {
val resources = TaskContext.get().resources()
resources("gpu").addresses.map(_.toInt)
} else {
null
Array.empty[Int]
}
if (gpuIndices.nonEmpty) {
OneCCL.setAffinityMask(gpuIndices(0).toString())
}
OneCCL.setAffinityMask(gpuIndices(0).toString())
Iterator.empty
}.count()
}
Expand Down

0 comments on commit bd44e63

Please sign in to comment.