diff --git a/mllib-dal/src/main/native/GPU.h b/mllib-dal/src/main/native/GPU.h index 818d3ddb4..39ce61789 100644 --- a/mllib-dal/src/main/native/GPU.h +++ b/mllib-dal/src/main/native/GPU.h @@ -11,3 +11,5 @@ sycl::queue getAssignedGPU(const ComputeDevice device, ccl::communicator &comm, int size, int rankId, jint *gpu_indices, int n_gpu); sycl::queue getQueue(const ComputeDevice device); + +sycl::queue getGPU(const ComputeDevice device,jint *gpu_indices); diff --git a/mllib-dal/src/main/native/OneCCL.cpp b/mllib-dal/src/main/native/OneCCL.cpp index febe83b0e..5385f3ebb 100644 --- a/mllib-dal/src/main/native/OneCCL.cpp +++ b/mllib-dal/src/main/native/OneCCL.cpp @@ -153,9 +153,9 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init( // rank_id = getComm().rank(); // comm_size = getComm().size(); // -// jclass cls = env->GetObjectClass(param); -// jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J"); -// jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J"); + jclass cls = env->GetObjectClass(param); + jfieldID fid_comm_size = env->GetFieldID(cls, "commSize", "J"); + jfieldID fid_rank_id = env->GetFieldID(cls, "rankId", "J"); env->SetLongField(param, fid_comm_size, size); env->SetLongField(param, fid_rank_id, rank);