Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Jul 29, 2024
1 parent 3b2ac68 commit 57d934f
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions mllib-dal/src/main/native/CorrelationImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
#include "oneapi/dal/table/common.hpp"
#include "oneapi/dal/io/csv.hpp"
#include "Logger.h"
#include <cstdlib> // for getenv
//#include <cstdlib> // for getenv

using namespace std;
#ifdef CPU_GPU_PROFILE
Expand Down Expand Up @@ -203,28 +203,28 @@ static void doCorrelationOneAPICompute(
// } else {
// std::cout << "Environment variable not found." << std::endl;
// }
for (char **env = environ; *env != nullptr; ++env) {
std::cout << *env << std::endl;
}
// for (char **env = environ; *env != nullptr; ++env) {
// std::cout << *env << std::endl;
// }
const bool isRoot = (comm.get_rank() == ccl_root);
auto t1 = std::chrono::high_resolution_clock::now();
auto input_vec = file_path("/home/damon/storage/DataRoot/HiBench_CSV/Correlation/Input/4000000");
const auto train_data_file_name = data_path(input_vec[comm.get_rank()]);
cout << "rank id = " << comm.get_rank() << " File name: " << train_data_file_name << endl;
const auto htable = read<table>(queue, csv::data_source{ train_data_file_name });
comm.barrier();

// float *htableArray = reinterpret_cast<float *>(pNumTabData);
// logger::println(logger::INFO, "numRows was %d", numRows);
// logger::println(logger::INFO, "numClos was %d", numClos);
// auto t1 = std::chrono::high_resolution_clock::now();
// auto input_vec = file_path("/home/damon/storage/DataRoot/HiBench_CSV/Correlation/Input/4000000");
// const auto train_data_file_name = data_path(input_vec[comm.get_rank()]);
// cout << "rank id = " << comm.get_rank() << " File name: " << train_data_file_name << endl;
// const auto htable = read<table>(queue, csv::data_source{ train_data_file_name });
// comm.barrier();

float *htableArray = reinterpret_cast<float *>(pNumTabData);
logger::println(logger::INFO, "numRows was %d", numRows);
logger::println(logger::INFO, "numClos was %d", numClos);
auto t1 = std::chrono::high_resolution_clock::now();

// auto data = sycl::malloc_shared<float>(numRows * numClos, queue);
// std::cout << "table size : " << numRows * numClos << std::endl;
// logger::Logger::getInstance(breakdown_name).printLogToFile("rankID was %d, table size %ld.", comm.get_rank(), numRows * numClos );
// queue.memcpy(data, htableArray, sizeof(float) * numRows * numClos).wait();
// homogen_table htable{queue, data, numRows, numClos,
// detail::make_default_delete<const float>(queue)};
auto data = sycl::malloc_shared<float>(numRows * numClos, queue);
std::cout << "table size : " << numRows * numClos << std::endl;
logger::Logger::getInstance(breakdown_name).printLogToFile("rankID was %d, table size %ld.", comm.get_rank(), numRows * numClos );
queue.memcpy(data, htableArray, sizeof(float) * numRows * numClos).wait();
homogen_table htable{queue, data, numRows, numClos,
detail::make_default_delete<const float>(queue)};
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
Expand Down

0 comments on commit 57d934f

Please sign in to comment.