Skip to content

Commit 121ccc1

Browse files
committed
update
1 parent 984ac8f commit 121ccc1

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

mllib-dal/src/main/native/CorrelationImpl.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,20 @@ static void doCorrelationOneAPICompute(
282282
}
283283
}
284284
#endif
285+
std::vector<sycl::device> test_gpus()
286+
{
287+
288+
auto platforms = sycl::platform::get_platforms();
289+
for (auto p : platforms) {
290+
auto devices = p.get_devices(sycl::info::device_type::gpu);
291+
if (!devices.empty()) {
292+
return devices;
293+
}
294+
}
295+
std::cout << "No GPUs!" << std::endl;
296+
exit(-3);
297+
return {};
298+
}
285299

286300
JNIEXPORT jlong JNICALL
287301
Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
@@ -319,7 +333,8 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
319333
"oneDAL (native): use GPU kernels with %d GPU(s) rankid %d", nGpu,
320334
rank);
321335
jint *gpuIndices = env->GetIntArrayElements(gpuIdxArray, 0);
322-
auto queue = getGPU(device, gpuIndices);
336+
auto gpus = test_gpus()
337+
// auto queue = getGPU(device, gpuIndices);
323338
// auto gpu_device = sycl::device(sycl::gpu_selector_v);
324339
// sycl::queue queue{gpu_device};
325340
const char* cstr = env->GetStringUTFChars(breakdown_name, nullptr);
@@ -358,6 +373,7 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
358373
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
359374
duration / 1000);
360375
logger::Logger::getInstance(c_breakdown_name).printLogToFile("rankID was %d, OneCCL create communicator took %f secs.", rank, duration / 1000 );
376+
sycl::queue queue{gpus[gpu_indices[0]]};
361377

362378
t1 = std::chrono::high_resolution_clock::now();
363379
auto comm =

0 commit comments

Comments
 (0)