@@ -282,6 +282,20 @@ static void doCorrelationOneAPICompute(
282282 }
283283}
284284#endif
285+ std::vector<sycl::device> test_gpus ()
286+ {
287+
288+ auto platforms = sycl::platform::get_platforms ();
289+ for (auto p : platforms) {
290+ auto devices = p.get_devices (sycl::info::device_type::gpu);
291+ if (!devices.empty ()) {
292+ return devices;
293+ }
294+ }
295+ std::cout << " No GPUs!" << std::endl;
296+ exit (-3 );
297+ return {};
298+ }
285299
286300JNIEXPORT jlong JNICALL
287301Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL (
@@ -319,7 +333,8 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
319333 " oneDAL (native): use GPU kernels with %d GPU(s) rankid %d" , nGpu,
320334 rank);
321335 jint *gpuIndices = env->GetIntArrayElements (gpuIdxArray, 0 );
322- auto queue = getGPU (device, gpuIndices);
336+ auto gpus = test_gpus ()
337+ // auto queue = getGPU(device, gpuIndices);
323338// auto gpu_device = sycl::device(sycl::gpu_selector_v);
324339// sycl::queue queue{gpu_device};
325340 const char * cstr = env->GetStringUTFChars (breakdown_name, nullptr );
@@ -358,6 +373,7 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
358373 logger::println (logger::INFO, " OneCCL (native): init took %f secs" ,
359374 duration / 1000 );
360375 logger::Logger::getInstance (c_breakdown_name).printLogToFile (" rankID was %d, OneCCL create communicator took %f secs." , rank, duration / 1000 );
376+ sycl::queue queue{gpus[gpu_indices[0 ]]};
361377
362378 t1 = std::chrono::high_resolution_clock::now ();
363379 auto comm =
0 commit comments