@@ -282,6 +282,20 @@ static void doCorrelationOneAPICompute(
282
282
}
283
283
}
284
284
#endif
285
+ std::vector<sycl::device> test_gpus ()
286
+ {
287
+
288
+ auto platforms = sycl::platform::get_platforms ();
289
+ for (auto p : platforms) {
290
+ auto devices = p.get_devices (sycl::info::device_type::gpu);
291
+ if (!devices.empty ()) {
292
+ return devices;
293
+ }
294
+ }
295
+ std::cout << " No GPUs!" << std::endl;
296
+ exit (-3 );
297
+ return {};
298
+ }
285
299
286
300
JNIEXPORT jlong JNICALL
287
301
Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL (
@@ -319,7 +333,8 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
319
333
" oneDAL (native): use GPU kernels with %d GPU(s) rankid %d" , nGpu,
320
334
rank);
321
335
jint *gpuIndices = env->GetIntArrayElements (gpuIdxArray, 0 );
322
- auto queue = getGPU (device, gpuIndices);
336
+ auto gpus = test_gpus ()
337
+ // auto queue = getGPU(device, gpuIndices);
323
338
// auto gpu_device = sycl::device(sycl::gpu_selector_v);
324
339
// sycl::queue queue{gpu_device};
325
340
const char * cstr = env->GetStringUTFChars (breakdown_name, nullptr );
@@ -358,6 +373,7 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
358
373
logger::println (logger::INFO, " OneCCL (native): init took %f secs" ,
359
374
duration / 1000 );
360
375
logger::Logger::getInstance (c_breakdown_name).printLogToFile (" rankID was %d, OneCCL create communicator took %f secs." , rank, duration / 1000 );
376
+ sycl::queue queue{gpus[gpu_indices[0 ]]};
361
377
362
378
t1 = std::chrono::high_resolution_clock::now ();
363
379
auto comm =
0 commit comments