Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
minmingzhu committed Jul 16, 2024
1 parent 2959da4 commit e0e8ef2
Show file tree
Hide file tree
Showing 2 changed files with 273 additions and 4 deletions.
112 changes: 108 additions & 4 deletions mllib-dal/src/main/native/OneCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
#include "Logger.h"
#include "OneCCL.h"
#include "com_intel_oap_mllib_OneCCL__.h"
#include "store.hpp"

// Timeout, in seconds, handed to file_store (as std::chrono::seconds) in c_init.
#define STORE_TIMEOUT_SEC 120
// Status codes returned by create_kvs_by_store().
#define KVS_CREATE_SUCCESS 0
#define KVS_CREATE_FAILURE -1

// Defined with `extern` so this const has external linkage and can be
// referenced from other translation units.
// NOTE(review): presumably the root rank used by collective operations —
// confirm against the callers, which are not visible here.
extern const size_t ccl_root = 0;

Expand All @@ -44,20 +49,69 @@ std::vector<ccl::shared_ptr_class<ccl::kvs>> g_kvs;

// Returns the first communicator stored in g_comms.
// Precondition: g_comms is non-empty (no bounds check is performed).
ccl::communicator &getComm() {
    return g_comms.front();
}
// Returns the first KVS handle stored in g_kvs.
// Precondition: g_kvs is non-empty (no bounds check is performed).
ccl::shared_ptr_class<ccl::kvs> &getKvs() {
    return g_kvs.front();
}
// File-backed store assigned in c_init and handed to create_kvs_by_store()
// to exchange the main KVS address between ranks.
std::shared_ptr<file_store> store;

JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1init(
JNIEnv *env, jobject obj, jint size, jint rank, jstring ip_port, jstring name,
jobject param) {

logger::println(logger::INFO, "OneCCL (native): init");

store = std::make_shared<file_store>(
kvs_param, rank, std::chrono::seconds(STORE_TIMEOUT_SEC));

const char *str = env->GetStringUTFChars(ip_port, 0);
ccl::string ccl_ip_port(str);
const char *str_name = env->GetStringUTFChars(name, 0);
ccl::string ccl_name(str_name);

auto &singletonCCLInit = CCLInitSingleton::get(size, rank, ccl_ip_port, ccl_name);
auto t1 = std::chrono::high_resolution_clock::now();
ccl::init();

auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();

logger::println(logger::INFO, "OneCCL singleton init took %f secs",
duration / 1000);
logger::Logger::getInstance(name).printLogToFile("rankID was %d, OneCCL singleton init took %f secs.", rank, duration / 1000 );

if (create_kvs_by_store(store, rank, kvs, ccl_name) != KVS_CREATE_SUCCESS) {
std::cout << "can not create kvs by store" << std::endl;
return -1;
}

t1 = std::chrono::high_resolution_clock::now();
logger::println(logger::INFO, "OneCCL (native): create_kvs_attr");

auto kvs_attr = ccl::create_kvs_attr();

kvs_attr.set<ccl::kvs_attr_id::ip_port>(ccl_ip_port);
logger::println(logger::INFO, "OneCCL (native): create_main_kvs");

ccl::shared_ptr_class<ccl::kvs> kvs = ccl::create_main_kvs(kvs_attr);
logger::println(logger::INFO, "OneCCL (native): g_ccl_kvs.push_back(kvs)");

{
std::lock_guard<std::mutex> lock(g_mtx);
g_kvs.push_back(kvs);
}
logger::println(logger::INFO, "OneCCL (native): ccl::create_communicator(size, rank, kvs)");
logger::println(logger::INFO, "ccl::create_communicator %d ,%d", size, rank);
auto comm = ccl::create_communicator(size, rank, kvs);
{
std::lock_guard<std::mutex> lock(g_mtx);
g_comms.push_back(ccl::create_communicator(size, rank, kvs));
}
logger::println(logger::INFO, "OneCCL (native): ccl::create_communicator finished");

t2 = std::chrono::high_resolution_clock::now();
duration =
(float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
logger::println(logger::INFO, "OneCCL (native): init took %f secs",
duration / 1000);
logger::Logger::getInstance(name).printLogToFile("rankID was %d, OneCCL create communicator took %f secs.", rank, duration / 1000 );


rank_id = getComm().rank();
comm_size = getComm().size();
Expand Down Expand Up @@ -91,8 +145,12 @@ Java_com_intel_oap_mllib_OneCCL_00024_c_1initDpcpp(JNIEnv *env, jobject) {
/**
 * JNI entry point: drops every KVS handle and communicator held in the
 * global g_kvs / g_comms vectors, printing their size and capacity before
 * and after each clear.
 *
 * The vectors are cleared directly instead of calling pop_back() first:
 * pop_back() on an empty vector is undefined behaviour (e.g. when cleanup
 * runs after a failed or skipped init), and clear() already removes every
 * element. The globals are modified under g_mtx, the same mutex init uses
 * when pushing into them.
 *
 * @param env JNI environment (unused).
 * @param obj Calling Java object (unused).
 */
JNIEXPORT void JNICALL
Java_com_intel_oap_mllib_OneCCL_00024_c_1cleanup(JNIEnv *env, jobject obj) {
    logger::printerrln(logger::INFO, "OneCCL (native): cleanup");

    std::lock_guard<std::mutex> lock(g_mtx);

    // KVS handles are released first, then the communicators.
    std::cout << "Size before clear: " << g_kvs.size() << ", Capacity: " << g_kvs.capacity() << std::endl;
    g_kvs.clear();
    std::cout << "Size after clear: " << g_kvs.size() << ", Capacity: " << g_kvs.capacity() << std::endl;

    std::cout << "Size before clear: " << g_comms.size() << ", Capacity: " << g_comms.capacity() << std::endl;
    g_comms.clear();
    std::cout << "Size after clear: " << g_comms.size() << ", Capacity: " << g_comms.capacity() << std::endl;
}

JNIEXPORT jboolean JNICALL
Expand Down Expand Up @@ -223,3 +281,49 @@ JNIEXPORT jint JNICALL Java_com_intel_oap_mllib_OneCCL_00024_c_1getAvailPort(

return port;
}

/**
 * Creates this rank's oneCCL KVS handle, exchanging the main KVS address
 * through a file-backed store.
 *
 * Rank 0 creates the main KVS and writes its address into `store`; every
 * other rank reads that address from `store` and attaches to it with
 * ccl::create_kvs().
 *
 * @param store File store used to publish (rank 0) or fetch (other ranks)
 *              the main KVS address.
 * @param rank  Rank of the calling process; rank 0 is the publisher.
 * @param kvs   Out-parameter receiving the KVS handle; reset on failure.
 * @param name  Passed to logger::Logger::getInstance() for the timing log.
 * @return KVS_CREATE_SUCCESS on success; KVS_CREATE_FAILURE when the store
 *         reports an error or any std::exception is thrown (file_store
 *         throws on open failure and on read timeout). Exceptions are
 *         caught here so they cannot escape into the JNI caller.
 *
 * NOTE(review): this function is `static` and defined after its call site
 * in c_init; a forward declaration above c_init is needed for that call to
 * compile.
 */
static int create_kvs_by_store(std::shared_ptr<file_store> store,
                               int rank,
                               ccl::shared_ptr_class<ccl::kvs>& kvs,
                               ccl::string name) {
    logger::println(logger::INFO, "OneCCL (native): create_kvs_by_store ");
    const auto t1 = std::chrono::high_resolution_clock::now();
    ccl::kvs::address_type main_addr;
    const auto start = std::chrono::system_clock::now();

    try {
        if (rank == 0) {
            kvs = ccl::create_main_kvs();
            main_addr = kvs->get_address();
            if (store->write(static_cast<const void*>(main_addr.data()), main_addr.size()) < 0) {
                logger::println(logger::INFO, "OneCCL (native): error occurred during write attempt");
                kvs.reset();
                return KVS_CREATE_FAILURE;
            }
            const auto end = std::chrono::system_clock::now();
            const auto exec_time =
                (float)std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
                    .count();
            logger::println(logger::INFO, "OneCCL (native): write to store time %f secs",
                            exec_time / 1000);
        }
        else {
            if (store->read(static_cast<void*>(main_addr.data()), main_addr.size()) < 0) {
                logger::println(logger::INFO, "OneCCL (native): error occurred during read attempt");
                kvs.reset();
                return KVS_CREATE_FAILURE;
            }
            const auto end = std::chrono::system_clock::now();
            const auto exec_time =
                (float)std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
                    .count();
            logger::println(logger::INFO, "OneCCL (native): read from store time %f secs",
                            exec_time / 1000);
            kvs = ccl::create_kvs(main_addr);
        }
    }
    catch (const std::exception& e) {
        // Covers file_store's timeout / open errors and anything thrown while
        // creating the KVS; report it as a regular failure.
        logger::println(logger::INFO, "OneCCL (native): create_kvs_by_store failed: %s", e.what());
        kvs.reset();
        return KVS_CREATE_FAILURE;
    }

    const auto t2 = std::chrono::high_resolution_clock::now();
    const auto duration =
        (float)std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
            .count();
    // NOTE(review): the message says "create communicator" although this
    // measures KVS creation; the text is kept unchanged in case logs are parsed.
    logger::Logger::getInstance(name).printLogToFile("rankID was %d, OneCCL create communicator took %f secs.", rank, duration / 1000 );
    return KVS_CREATE_SUCCESS;
}

165 changes: 165 additions & 0 deletions mllib-dal/src/main/native/store.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
#pragma once

#include <fcntl.h>
#include <sys/file.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <exception>
#include <mutex>
#include <stdexcept>
#include <string>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>

// Throws std::system_error built from the current errno when `ret` is
// negative. `msg` is only evaluated on failure. Wrapped in do/while(0) so the
// macro behaves as a single statement (safe inside unbraced if/else) and
// always requires a trailing semicolon.
#define CHECK(ret, msg) \
    do { \
        if ((ret) < 0) { \
            throw std::system_error(errno, std::system_category(), msg); \
        } \
    } while (0)

/// Abstract interface of a byte store that can be written to and read from.
class base_store {
public:
    base_store() = default;

    virtual ~base_store() = default;

    /// Stores `size` bytes starting at `data`; negative return means failure.
    virtual int write(const void* data, size_t size) = 0;

    /// Fetches up to `size` bytes into `data`; negative return means failure.
    virtual int read(void* data, size_t size) = 0;
};

/**
 * Store backed by a single file at `path`, guarded by flock().
 *
 * write() appends to the file under an exclusive lock; read() waits until
 * the file exists and is non-empty, then consumes bytes starting from the
 * position where the previous read() on this object stopped. Calls on one
 * object are serialized by an internal mutex. The rank-0 instance removes
 * the file when it is destroyed.
 */
class file_store : public base_store {
public:
    file_store(const file_store& other) = delete;
    file_store& operator=(const file_store& other) = delete;

    /**
     * @param path    File used as the backing store.
     * @param rank    Rank of the owner; scales the polling interval in
     *                read() (10 ms * rank) and selects who deletes the file.
     * @param timeout Maximum time read() waits for the file to appear or to
     *                receive its first bytes.
     */
    file_store(std::string path, int rank, const std::chrono::seconds& timeout)
            : base_store(),
              path(std::move(path)),
              rank(rank),
              pos(0),
              fd(-1),
              timeout(timeout) {}

    ~file_store() override {
        if (rank == 0)
            std::remove(path.c_str());
    }

    /// Unlocks and closes the currently open descriptor, if there is one.
    void release_resources() {
        if (fd < 0) {
            // nothing is open; avoid flock()/close() on an invalid descriptor
            return;
        }

        try {
            CHECK(flock(fd, LOCK_UN), "Unlocking file: ");
        }
        catch (const std::system_error& e) {
            fprintf(stderr, "%d\n%s\n", e.code().value(), e.what());
        }

        close(fd);
        fd = -1;
    }

    /**
     * Appends `size` bytes from `data` to the file and flushes them to disk.
     *
     * @return 0 on success, -1 if locking, seeking, writing or fsync fails
     *         (the error is printed to stderr).
     * @throws std::system_error if the file cannot be opened/created.
     */
    int write(const void* data, size_t size) override {
        int ret = 0;
        std::unique_lock<std::mutex> locker(mtx);
        fd = open(path.c_str(), O_CREAT | O_RDWR, 0644);
        CHECK(fd, "Open file to write into (" + path + "): ");

        const auto* src = static_cast<const std::uint8_t*>(data);
        try {
            CHECK(flock(fd, LOCK_EX), "Setting exclusive rights for writing to the file: ");
            CHECK(lseek(fd, 0, SEEK_END), "Setting a cursor at the EOF: ");

            // writing into the file; ::write may store fewer bytes than asked
            while (size > 0) {
                const auto wr_v = ::write(fd, src, size);
                CHECK(wr_v, "An error occured while writing to the file: ");
                src += wr_v;
                size -= static_cast<size_t>(wr_v);
            }
            CHECK(fsync(fd), "Flushing file content: ");
        }
        catch (const std::system_error& e) {
            fprintf(stderr, "%d\n%s\n", e.code().value(), e.what());
            ret = -1;
        }

        release_resources();
        return ret;
    }

    /**
     * Reads at most `size` bytes into `data`, starting at the offset where
     * the previous read() on this object stopped.
     *
     * Blocks (polling every 10 ms * rank) until the file exists and is
     * non-empty. Never writes more than `size` bytes into `data`; if fewer
     * bytes are available past the current offset, only those are copied.
     *
     * @return 0 on success, -1 if locking, seeking or reading fails (the
     *         error is printed to stderr).
     * @throws std::runtime_error if the file does not appear, or stays
     *         empty, for longer than the configured timeout.
     * @throws std::system_error if opening the file fails for a reason other
     *         than ENOENT.
     */
    int read(void* data, size_t size) override {
        const auto time_start = std::chrono::steady_clock::now();
        auto* dst = static_cast<std::uint8_t*>(data);

        while (true) {
            std::unique_lock<std::mutex> locker(mtx);
            fd = open(path.c_str(), O_RDONLY);
            if (fd < 0 && errno == ENOENT) {
                // file might not exist yet; do not hold the mutex while waiting
                locker.unlock();
                wait_or_throw(time_start, "open");
                continue;
            }
            CHECK(fd, "Open the file to read from (" + path + "): ");

            try {
                CHECK(flock(fd, LOCK_SH), "Setting shared rights for reading the file: ");

                auto start = lseek(fd, 0, SEEK_SET);
                CHECK(start, "Setting the cursor at the beginning of the file: ");

                // find the real size of the file
                const auto len = lseek(fd, 0, SEEK_END);
                CHECK(len, "Setting the cursor at the EOF: ");

                if (len == start) {
                    // nothing has been written yet
                    release_resources();
                    locker.unlock();
                    wait_or_throw(time_start, "read");
                    continue;
                }

                // start from where we stopped last time
                start = lseek(fd, pos, SEEK_SET);
                CHECK(start, "Setting the cursor at the last known position: ");

                // nothing left past the saved offset, or nothing more requested
                if (len <= start || size == 0) {
                    release_resources();
                    break;
                }

                // never copy more than the caller's buffer holds or than the
                // file still contains past the current offset
                size_t chunk = std::min(size, static_cast<size_t>(len - start));
                bool hit_eof = false;
                while (chunk > 0) {
                    const auto rd = ::read(fd, dst, chunk);
                    CHECK(rd, "An error occured while reading the file: ");
                    if (rd == 0) {
                        // unexpected end of file: stop instead of spinning
                        hit_eof = true;
                        break;
                    }
                    dst += rd;
                    size -= static_cast<size_t>(rd);
                    chunk -= static_cast<size_t>(rd);
                }

                const auto cur = lseek(fd, 0, SEEK_CUR);
                CHECK(cur, "Saving the cursor current position: ");
                pos = cur;

                // drop the shared lock and the descriptor before the next
                // iteration opens a new one
                release_resources();
                if (hit_eof || size == 0) {
                    break;
                }
            }
            catch (const std::system_error& e) {
                fprintf(stderr, "%d\n%s\n", e.code().value(), e.what());
                release_resources();
                return -1;
            }
        }
        return 0;
    }

protected:
    std::string path;
    int rank;
    off_t pos;
    int fd;
    std::chrono::seconds timeout;
    std::mutex mtx;

private:
    /// Throws once `timeout` has elapsed since `time_start`; otherwise sleeps
    /// for one polling interval (10 ms * rank).
    void wait_or_throw(const std::chrono::steady_clock::time_point& time_start,
                       const char* action) const {
        const auto time_passed = std::chrono::duration_cast<std::chrono::seconds>(
            std::chrono::steady_clock::now() - time_start);
        if (time_passed > timeout) {
            throw std::runtime_error("Timeout " + std::to_string(timeout.count()) +
                                     "s waiting for the file " + path + " to " + action);
        }
        std::this_thread::sleep_for(std::chrono::milliseconds(10 * rank));
    }
};

0 comments on commit e0e8ef2

Please sign in to comment.