Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Oct 18, 2024
1 parent 4487b66 commit f8c9a59
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
21 changes: 14 additions & 7 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
#include <oneapi/ccl.hpp>

#include "CCLInitSingleton.hpp"
#include "GPU.h"
#include "Logger.h"
#include "OneCCL.h"
#include "com_intel_oap_mllib_OneCCL__.h"
#include "service.h"
#include "GPU.h"

// Definition of ccl_root, which OneCCL.h declares as `extern const size_t`.
// The explicit `extern` is required on the definition because a namespace-scope
// const object otherwise has internal linkage in C++.
// NOTE(review): presumably the rank index treated as the collective root — confirm.
extern const size_t ccl_root = 0;

Expand All @@ -46,8 +46,14 @@ static std::vector<ccl::shared_ptr_class<ccl::kvs>> g_kvs;

// Returns a reference to element 0 of g_comms (declared outside this view).
// NOTE(review): there is no emptiness check here — presumably callers only use
// this after initialization has populated g_comms; confirm.
ccl::communicator &getComm() { return g_comms[0]; }
#ifdef CPU_GPU_PROFILE
// File-local storage for oneDAL SPMD communicators using USM device memory
// access; c_init appends to it via push_back.
static std::vector<oneapi::dal::preview::spmd::communicator<oneapi::dal::preview::spmd::device_memory_access::usm>> g_dal_comms;
// Returns a reference to the first communicator stored in g_dal_comms.
// NOTE(review): no emptiness check — presumably only called after c_init; confirm.
oneapi::dal::preview::spmd::communicator<oneapi::dal::preview::spmd::device_memory_access::usm> &getDalComm() { return g_dal_comms[0]; }
static std::vector<oneapi::dal::preview::spmd::communicator<
oneapi::dal::preview::spmd::device_memory_access::usm>>
g_dal_comms;
oneapi::dal::preview::spmd::communicator<
oneapi::dal::preview::spmd::device_memory_access::usm> &
getDalComm() {
return g_dal_comms[0];
}

ccl::shared_ptr_class<ccl::kvs> &getKvs() { return g_kvs[0]; }
#endif
Expand All @@ -58,7 +64,7 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
logger::println(logger::INFO, "OneCCL (native): init");

#ifdef CPU_GPU_PROFILE
auto gpus = get_gpus();
auto gpus = get_gpus();
#endif

auto t1 = std::chrono::high_resolution_clock::now();
Expand Down Expand Up @@ -106,13 +112,14 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
duration / 1000);
sycl::queue queue{gpus[0]};
t1 = std::chrono::high_resolution_clock::now();
auto comm = oneapi::dal::preview::spmd::make_communicator<oneapi::dal::preview::spmd::backend::ccl>(
queue, size, rank, kvs);
auto comm = oneapi::dal::preview::spmd::make_communicator<
oneapi::dal::preview::spmd::backend::ccl>(queue, size, rank, kvs);
t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): create communicator took %f secs",
logger::println(logger::INFO,
"OneCCL (native): create communicator took %f secs",
duration / 1000);
g_dal_comms.push_back(comm);
#endif
Expand Down
4 changes: 3 additions & 1 deletion mllib-dal/src/main/native/OneCCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ ccl::shared_ptr_class<ccl::kvs> &getKvs();
#define ONEDAL_DATA_PARALLEL
#endif
#include "Communicator.hpp"
// Accessor for the oneDAL SPMD communicator (USM device memory access);
// defined in OneCCL.cpp, where it returns the first stored communicator.
oneapi::dal::preview::spmd::communicator<oneapi::dal::preview::spmd::device_memory_access::usm> &getDalComm();
// Accessor for the oneDAL SPMD communicator (USM device memory access);
// defined in OneCCL.cpp, where it returns the first stored communicator.
oneapi::dal::preview::spmd::communicator<
oneapi::dal::preview::spmd::device_memory_access::usm> &
getDalComm();
#endif
// Declaration only; the definition (initialized to 0) lives in OneCCL.cpp.
extern const size_t ccl_root;

0 comments on commit f8c9a59

Please sign in to comment.