2929#include  < oneapi/ccl.hpp> 
3030
3131#include  " CCLInitSingleton.hpp" 
32+ #include  " GPU.h" 
3233#include  " Logger.h" 
3334#include  " OneCCL.h" 
3435#include  " com_intel_oap_mllib_OneCCL__.h" 
3536#include  " service.h" 
36- #include  " GPU.h" 
3737
// Definition of the constant ccl_root, initialised to 0. `extern` gives this
// const object external linkage so other translation units can refer to it.
// NOTE(review): presumably the rank treated as "root" in collective
// operations; confirm against the call sites, which are not visible here.
3838extern  const  size_t  ccl_root = 0 ;
3939
@@ -46,8 +46,14 @@ static std::vector<ccl::shared_ptr_class<ccl::kvs>> g_kvs;
4646
// Returns a mutable reference to element 0 of the file-level g_comms
// container. No emptiness check is performed before indexing.
// NOTE(review): assumes g_comms has already been populated by the init path
// before any caller reaches this accessor; its declaration is not visible
// here, so confirm.
4747ccl::communicator &getComm () { return  g_comms[0 ]; }
4848#ifdef  CPU_GPU_PROFILE
// Diff view of two declarations: the '-' lines are the previous one-line
// forms, the '+' lines are the same declarations re-wrapped over several
// lines. The tokens are identical, so this is a formatting-only change.
//
// g_dal_comms: file-local (static) std::vector of oneDAL SPMD communicators
// parameterised on device_memory_access::usm.
// getDalComm(): returns a mutable reference to g_dal_comms[0].
// std::vector::operator[] does no bounds checking, so calling this while the
// vector is empty is undefined behaviour.
49- static  std::vector<oneapi::dal::preview::spmd::communicator<oneapi::dal::preview::spmd::device_memory_access::usm>> g_dal_comms;
50- oneapi::dal::preview::spmd::communicator<oneapi::dal::preview::spmd::device_memory_access::usm> &getDalComm () { return  g_dal_comms[0 ]; }
49+ static  std::vector<oneapi::dal::preview::spmd::communicator<
50+     oneapi::dal::preview::spmd::device_memory_access::usm>>
51+     g_dal_comms;
52+ oneapi::dal::preview::spmd::communicator<
53+     oneapi::dal::preview::spmd::device_memory_access::usm> &
54+ getDalComm () {
55+     return  g_dal_comms[0 ];
56+ }
5157
// Returns a mutable reference to g_kvs[0], the first shared ccl::kvs handle
// held in the file-level g_kvs vector. std::vector::operator[] does no bounds
// checking, so the vector must be non-empty when this is called.
// NOTE(review): this accessor sits between #ifdef CPU_GPU_PROFILE and its
// #endif, so it is only compiled when that macro is defined, while g_kvs
// appears to be declared before the guard; confirm the placement is intended.
5258ccl::shared_ptr_class<ccl::kvs> &getKvs () { return  g_kvs[0 ]; }
5359#endif 
@@ -58,7 +64,7 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
5864    logger::println (logger::INFO, " OneCCL (native): init" 
5965
6066#ifdef  CPU_GPU_PROFILE
61-       auto  gpus = get_gpus ();
67+     auto  gpus = get_gpus ();
6268#endif 
6369
6470    auto  t1 = std::chrono::high_resolution_clock::now ();
@@ -106,13 +112,14 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
106112                    duration / 1000 );
107113    sycl::queue queue{gpus[0 ]};
108114    t1 = std::chrono::high_resolution_clock::now ();
109-     auto  comm = oneapi::dal::preview::spmd::make_communicator<oneapi::dal::preview::spmd::backend::ccl>( 
110-         queue, size, rank, kvs);
115+     auto  comm = oneapi::dal::preview::spmd::make_communicator<
116+         oneapi::dal::preview::spmd::backend::ccl>( queue, size, rank, kvs);
111117    t2 = std::chrono::high_resolution_clock::now ();
112118    duration =
113119        (float )std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
114120            .count ();
115-     logger::println (logger::INFO, " OneCCL (native): create communicator took %f secs" 
121+     logger::println (logger::INFO,
122+                     " OneCCL (native): create communicator took %f secs" 
116123                    duration / 1000 );
117124    g_dal_comms.push_back (comm);
118125#endif 
0 commit comments