Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
  • Loading branch information
minmingzhu committed Jul 29, 2024
1 parent 57d934f commit d65e29b
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
45 changes: 40 additions & 5 deletions mllib-dal/src/main/native/CorrelationImpl.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -287,7 +287,7 @@ JNIEXPORT jlong JNICALL
Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
JNIEnv *env, jobject obj, jint rank, jlong pNumTabData, jlong numRows, jlong numClos,
jint executorNum, jint executorCores, jint computeDeviceOrdinal,
jintArray gpuIdxArray, jstring breakdown_name, jobject resultObj) {
jintArray gpuIdxArray, jstring ip_port, jstring breakdown_name, jobject resultObj) {
logger::println(logger::INFO,
"oneDAL (native): use DPC++ kernels; device %s",
ComputeDeviceString[computeDeviceOrdinal].c_str());
Expand Down Expand Up @@ -322,18 +322,52 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(
jint *gpuIndices = env->GetIntArrayElements(gpuIdxArray, 0);
const char* cstr = env->GetStringUTFChars(breakdown_name, nullptr);
std::string c_breakdown_name(cstr);
const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);

auto t1 = std::chrono::high_resolution_clock::now();
logger::println(logger::INFO, "CCLInitSingleton name %s",
name);
ccl::init();

auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();

logger::println(logger::INFO, "OneCCL singleton init took %f secs",
duration / 1000);
logger::Logger::getInstance(name).printLogToFile("rankID was %d, OneCCL singleton init took %f secs.", rank, duration / 1000 );


t1 = std::chrono::high_resolution_clock::now();
logger::println(logger::INFO, "OneCCL (native): create_kvs_attr");

auto kvs_attr = ccl::create_kvs_attr();

kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);
logger::println(logger::INFO, "OneCCL (native): create_main_kvs");

auto kvs = ccl::create_main_kvs(kvs_attr);
logger::println(logger::INFO, "OneCCL (native): g_ccl_kvs.push_back(kvs)");

t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);
logger::Logger::getInstance(name).printLogToFile("rankID was %d, OneCCL create communicator took %f secs.", rank, duration / 1000 );

// auto queue = getGPU(device, gpuIndices);
auto device = sycl::device(sycl::gpu_selector_v);
sycl::queue queue{device};

ccl::shared_ptr_class<ccl::kvs> &kvs = getKvs();
auto t1 = std::chrono::high_resolution_clock::now();
t1 = std::chrono::high_resolution_clock::now();
auto comm =
preview::spmd::make_communicator<preview::spmd::backend::ccl>(
queue, executorNum, rank, kvs);
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::Logger::getInstance(c_breakdown_name).printLogToFile("rankID was %d, create communicator took %f secs.", rank, duration / 1000 );
Expand All @@ -342,6 +376,7 @@ Java_com_intel_oap_mllib_stat_CorrelationDALImpl_cCorrelationTrainDAL(

env->ReleaseIntArrayElements(gpuIdxArray, gpuIndices, 0);
env->ReleaseStringUTFChars(breakdown_name, cstr);
env->ReleaseStringUTFChars(ip_port, str);
break;
}
#endif
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number | Diff line number | Diff line change
Expand Up @@ -65,10 +65,12 @@ class CorrelationDALImpl(
Iterator.empty
}.count()

coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, training_breakdown_name, storePath)
Iterator.empty
}.count()
if (useDevice == "CPU") {
coalescedTables.mapPartitionsWithIndex { (rank, table) =>
OneCCL.init(executorNum, rank, kvsIPPort, training_breakdown_name, storePath)
Iterator.empty
}.count()
}
corTimer.record("OneCCL Init")

val results = coalescedTables.mapPartitionsWithIndex { (rank, iter) =>
Expand Down Expand Up @@ -97,6 +99,7 @@ class CorrelationDALImpl(
executorCores,
computeDevice.ordinal(),
gpuIndices,
kvsIPPort,
training_breakdown_name,
result
)
Expand Down Expand Up @@ -148,6 +151,7 @@ class CorrelationDALImpl(
executorCores: Int,
computeDeviceOrdinal: Int,
gpuIndices: Array[Int],
kvsIPPort: String,
training_breakdown_name: String,
result: CorrelationResult): Long

Expand Down

0 comments on commit d65e29b

Please sign in to comment.