From 68cfb427b4a88bbf54258acf55e6c61c68c79fd6 Mon Sep 17 00:00:00 2001 From: minmingzhu Date: Wed, 21 Aug 2024 15:31:21 +0800 Subject: [PATCH] update --- mllib-dal/src/main/native/CorrelationImpl.cpp | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/mllib-dal/src/main/native/CorrelationImpl.cpp b/mllib-dal/src/main/native/CorrelationImpl.cpp index a4239a110..c4675cfae 100644 --- a/mllib-dal/src/main/native/CorrelationImpl.cpp +++ b/mllib-dal/src/main/native/CorrelationImpl.cpp @@ -202,13 +202,22 @@ static jlong doCorrelationOneAPICompute( logger::println(logger::INFO, "numClos was %d", numClos); auto t1 = std::chrono::high_resolution_clock::now(); - auto data = sycl::malloc_shared(numRows * numClos, queue); - std::cout << "table size : " << numRows * numClos << std::endl; - logger::Logger::getInstance(breakdown_name).printLogToFile("rankID was %d, table size %ld.", comm.get_rank(), numRows * numClos ); - queue.memcpy(data, htableArray, sizeof(float) * numRows * numClos).wait(); +// auto data = sycl::malloc_shared(numRows * numClos, queue); +// std::cout << "table size : " << numRows * numClos << std::endl; +// logger::Logger::getInstance(breakdown_name).printLogToFile("rankID was %d, table size %ld.", comm.get_rank(), numRows * numClos ); +// queue.memcpy(data, htableArray, sizeof(float) * numRows * numClos).wait(); + auto data = + oneapi::dal::array::empty(queue, numRows * numClos, sycl::usm::alloc::device); + + detail::memcpy_host2usm(queue, + data.get_mutable_data(), + htableArray, + sizeof(float) * numRows * numClos); + + homogen_table htable = homogen_table::wrap(data, numRows, numClos); freeArrayPtr(htableArray); - homogen_table htable{queue, data, numRows, numClos, - detail::make_default_delete(queue)}; +// homogen_table htable{queue, data, numRows, numClos, +// detail::make_default_delete(queue)}; auto t2 = std::chrono::high_resolution_clock::now(); auto duration = (float)std::chrono::duration_cast(t2 - t1)