29
29
#include < oneapi/ccl.hpp>
30
30
31
31
#include " CCLInitSingleton.hpp"
32
+ #include " GPU.h"
32
33
#include " Logger.h"
33
34
#include " OneCCL.h"
34
35
#include " com_intel_oap_mllib_OneCCL__.h"
35
36
#include " service.h"
36
- #include " GPU.h"
37
37
38
38
extern const size_t ccl_root = 0 ;
39
39
@@ -46,8 +46,14 @@ static std::vector<ccl::shared_ptr_class<ccl::kvs>> g_kvs;
46
46
47
47
ccl::communicator &getComm () { return g_comms[0 ]; }
48
48
#ifdef CPU_GPU_PROFILE
49
- static std::vector<oneapi::dal::preview::spmd::communicator<oneapi::dal::preview::spmd::device_memory_access::usm>> g_dal_comms;
50
- oneapi::dal::preview::spmd::communicator<oneapi::dal::preview::spmd::device_memory_access::usm> &getDalComm () { return g_dal_comms[0 ]; }
49
+ static std::vector<oneapi::dal::preview::spmd::communicator<
50
+ oneapi::dal::preview::spmd::device_memory_access::usm>>
51
+ g_dal_comms;
52
+ oneapi::dal::preview::spmd::communicator<
53
+ oneapi::dal::preview::spmd::device_memory_access::usm> &
54
+ getDalComm () {
55
+ return g_dal_comms[0 ];
56
+ }
51
57
52
58
ccl::shared_ptr_class<ccl::kvs> &getKvs () { return g_kvs[0 ]; }
53
59
#endif
@@ -58,7 +64,7 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
58
64
logger::println (logger::INFO, " OneCCL (native): init" );
59
65
60
66
#ifdef CPU_GPU_PROFILE
61
- auto gpus = get_gpus ();
67
+ auto gpus = get_gpus ();
62
68
#endif
63
69
64
70
auto t1 = std::chrono::high_resolution_clock::now ();
@@ -106,13 +112,14 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
106
112
duration / 1000 );
107
113
sycl::queue queue{gpus[0 ]};
108
114
t1 = std::chrono::high_resolution_clock::now ();
109
- auto comm = oneapi::dal::preview::spmd::make_communicator<oneapi::dal::preview::spmd::backend::ccl>(
110
- queue, size, rank, kvs);
115
+ auto comm = oneapi::dal::preview::spmd::make_communicator<
116
+ oneapi::dal::preview::spmd::backend::ccl>( queue, size, rank, kvs);
111
117
t2 = std::chrono::high_resolution_clock::now ();
112
118
duration =
113
119
(float )std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
114
120
.count ();
115
- logger::println (logger::INFO, " OneCCL (native): create communicator took %f secs" ,
121
+ logger::println (logger::INFO,
122
+ " OneCCL (native): create communicator took %f secs" ,
116
123
duration / 1000 );
117
124
g_dal_comms.push_back (comm);
118
125
#endif
0 commit comments