Skip to content

Commit 882cf68

Browse files
Binyang2014seagaterchhwang
authored
[Cherry-pick] NVLS support for NCCL API (#410) (#425)
Co-authored-by: Qinghua Zhou <[email protected]> Co-authored-by: Changho Hwang <[email protected]>
1 parent a30ce7c commit 882cf68

File tree

5 files changed

+91
-11
lines changed

5 files changed

+91
-11
lines changed

apps/nccl/include/nccl.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ typedef struct ncclConfig_v21700 {
6969
NCCL_CONFIG_UNDEF_INT /* splitShare */ \
7070
}
7171

72+
/* NCCL malloc and free functions for all types of NCCL optimizations
73+
* (e.g. user buffer registration). The actual allocated size might
74+
* be larger than requested due to granularity requirement. */
75+
ncclResult_t ncclMemAlloc(void** ptr, size_t size);
76+
ncclResult_t pncclMemAlloc(void** ptr, size_t size);
77+
78+
ncclResult_t ncclMemFree(void* ptr);
79+
ncclResult_t pncclMemFree(void* ptr);
80+
7281
/* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer.
7382
* This integer is coded with the MAJOR, MINOR and PATCH level of the
7483
* NCCL library

apps/nccl/src/nccl.cu

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <mscclpp/executor.hpp>
99
#include <mscclpp/sm_channel.hpp>
1010
#include <mscclpp/sm_channel_device.hpp>
11+
#include <mscclpp/utils.hpp>
1112
#include <sstream>
1213
#include <unordered_map>
1314
#include <vector>
@@ -33,6 +34,9 @@
3334
// mscclpp::Transport::IB3, mscclpp::Transport::IB4, mscclpp::Transport::IB5,
3435
// mscclpp::Transport::IB6, mscclpp::Transport::IB7};
3536

37+
// Declare the global map to store associations between raw pointer and shared pointer
38+
static std::unordered_map<void*, std::shared_ptr<char>> ptrMap;
39+
3640
struct channelKey {
3741
const void* buff;
3842
size_t bytes;
@@ -113,7 +117,7 @@ static size_t ncclTypeSize(ncclDataType_t type) {
113117
return 0;
114118
}
115119

116-
double parseSize(const char* value) {
120+
static double parseSize(const char* value) {
117121
std::string valueStr(value);
118122
std::istringstream iss(valueStr);
119123
long long int units;
@@ -644,3 +648,59 @@ NCCL_API ncclResult_t ncclGroupEnd() {
644648
// Do nothing
645649
return ncclSuccess;
646650
}
651+
652+
// Placeholder for NCCL user-buffer registration; no registration is performed yet.
//
// The communicator, buffer and size arguments are ignored. When `handle` is
// non-null, `*handle` is set to nullptr so that a caller never reads an
// indeterminate handle value after a successful return.
//
// Always returns ncclSuccess.
NCCL_API ncclResult_t ncclCommRegister(const ncclComm_t, void*, size_t, void** handle) {
  // TODO: Implementation
  if (handle != nullptr) {
    *handle = nullptr;
  }
  return ncclSuccess;
}
656+
657+
// Placeholder counterpart of ncclCommRegister(): both arguments are ignored
// and ncclSuccess is returned unconditionally.
NCCL_API ncclResult_t ncclCommDeregister(const ncclComm_t /*comm*/, void* /*handle*/) {
  // TODO: Implementation
  return ncclSuccess;
}
661+
662+
// Allocates `size` bytes through mscclpp and stores the resulting address in `*ptr`.
//
// The allocator is mscclpp::allocSharedPhysicalCuda when
// mscclpp::isNvlsSupported() returns true, and mscclpp::allocExtSharedCuda
// otherwise. The owning shared_ptr is stored in ptrMap, keyed by the raw
// address, so the allocation stays alive until ncclMemFree() erases that entry.
//
// Returns:
//   ncclSuccess            - `*ptr` holds the new allocation.
//   ncclInvalidArgument    - `ptr` is null or `size` is zero.
//   ncclSystemError        - the allocator returned a null pointer.
//   ncclInvalidUsage       - mscclpp::Error carrying ErrorCode::InvalidUsage.
//   ncclUnhandledCudaError - mscclpp::CudaError or mscclpp::CuError.
//   ncclInternalError      - any other exception (including a failure while
//                            recording the allocation in ptrMap).
//
// `*ptr` is written only on success. No exception leaves this function; every
// failure is reported through the return code.
ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
  if (ptr == nullptr || size == 0) {
    return ncclInvalidArgument;
  }
  try {
    std::shared_ptr<char> sharedPtr;
    if (mscclpp::isNvlsSupported()) {
      sharedPtr = mscclpp::allocSharedPhysicalCuda<char>(size);
    } else {
      sharedPtr = mscclpp::allocExtSharedCuda<char>(size);
    }
    if (sharedPtr == nullptr) {
      return ncclSystemError;
    }

    // Record ownership inside the try block: the map insertion can itself throw,
    // and in that case sharedPtr releases the allocation during unwinding.
    char* const rawPtr = sharedPtr.get();
    ptrMap[rawPtr] = sharedPtr;

    // Return the pointer
    *ptr = rawPtr;
    return ncclSuccess;
  } catch (const mscclpp::Error& e) {
    if (e.getErrorCode() == mscclpp::ErrorCode::InvalidUsage) {
      return ncclInvalidUsage;
    }
    return ncclInternalError;
  } catch (const mscclpp::CudaError&) {
    return ncclUnhandledCudaError;
  } catch (const mscclpp::CuError&) {
    return ncclUnhandledCudaError;
  } catch (const mscclpp::BaseError&) {
    return ncclInternalError;
  } catch (...) {
    // Anything else is turned into an error code rather than being allowed to
    // propagate out of this API entry point.
    return ncclInternalError;
  }
}
696+
697+
// Drops the ptrMap entry keyed by `ptr`, destroying the shared_ptr that
// ncclMemAlloc() stored for that address.
//
// Returns ncclSuccess when an entry was removed, and ncclInvalidUsage when
// `ptr` is not a key in ptrMap.
ncclResult_t ncclMemFree(void* ptr) {
  // erase(key) reports how many entries were removed: 1 if tracked, 0 if not.
  const bool wasTracked = ptrMap.erase(ptr) != 0;
  return wasTracked ? ncclSuccess : ncclInvalidUsage;
}

src/executor/execution_plan.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,17 @@ std::vector<ChannelInfo> ExecutionPlan::Impl::getUnpairedChannelInfos(int rank,
141141
return unpaired;
142142
}
143143

144-
std::vector<NvlsInfo> ExecutionPlan::Impl::getNvlsInfos(int rank) const { return this->nvlsInfos.at(rank); }
144+
std::vector<NvlsInfo> ExecutionPlan::Impl::getNvlsInfos(int rank, size_t sendBuffserSize, size_t recvBufferSize) const {
145+
if (sendBuffserSize == 0 && recvBufferSize == 0) {
146+
return this->nvlsInfos.at(rank);
147+
}
148+
size_t chunkSize = this->getUpperBoundChunkSize(rank, sendBuffserSize, recvBufferSize);
149+
std::vector<NvlsInfo> infos = this->nvlsInfos.at(rank);
150+
for (auto& info : infos) {
151+
info.bufferSize = info.bufferSize * chunkSize;
152+
}
153+
return infos;
154+
}
145155

146156
std::vector<int> ExecutionPlan::Impl::getConnectedPeers(int rank) const {
147157
std::set<int> peers;
@@ -272,7 +282,7 @@ void ExecutionPlan::Impl::parseChannels(
272282
NvlsInfo info;
273283
info.bufferType = convertToBufferType(channel["buff"]);
274284
for (const auto& group : channel["rankGroups"]) {
275-
info.bufferSize = (int)group["size"] * this->getUpperBoundChunkSize(rank, this->inputSize, this->outputSize);
285+
info.bufferSize = (int)group["size"];
276286
info.ranks.clear();
277287
for (int rank : group["ranks"]) {
278288
info.ranks.push_back(rank);

src/executor/executor.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,10 @@ struct Executor::Impl {
180180
context.scratchBufferSize = scratchBufferSize;
181181
context.proxyService = std::make_shared<ProxyService>();
182182
context.nthreadsPerBlock = plan.impl_->getNThreadsPerBlock();
183-
this->setupConnections(context, rank, plan);
183+
this->setupConnections(context, rank, plan, sendMemRange, recvMemRange);
184184
this->setupRegisteredMemories(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
185185
this->setupChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
186-
this->setupNvlsChannels(context, sendbuff, recvbuff, rank, plan);
186+
this->setupNvlsChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan);
187187
this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan);
188188
context.deviceExecutionPlansBuffers[devicePlanKey] =
189189
allocExtSharedCuda<char>(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan));
@@ -214,7 +214,8 @@ struct Executor::Impl {
214214
return flags;
215215
};
216216

217-
void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan) {
217+
void setupConnections(ExecutionContext& context, int rank, const ExecutionPlan& plan, size_t sendBufferSize,
218+
size_t recvBufferSize) {
218219
std::vector<int> connectedPeers = plan.impl_->getConnectedPeers(rank);
219220
std::vector<mscclpp::NonblockingFuture<std::shared_ptr<mscclpp::Connection>>> connectionFutures;
220221
for (int peer : connectedPeers) {
@@ -227,7 +228,7 @@ struct Executor::Impl {
227228
context.connections[connectedPeers[i]] = connectionFutures[i].get();
228229
}
229230

230-
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank);
231+
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize);
231232
for (const NvlsInfo& info : nvlsInfos) {
232233
std::shared_ptr<NvlsConnection> nvlsConnection =
233234
mscclpp::connectNvlsCollective(this->comm, info.ranks, info.bufferSize);
@@ -351,9 +352,9 @@ struct Executor::Impl {
351352
}
352353
}
353354

354-
void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, int rank,
355-
const ExecutionPlan& plan) {
356-
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank);
355+
void setupNvlsChannels(ExecutionContext& context, void* sendbuff, void* recvbuff, size_t sendBufferSize,
356+
size_t recvBufferSize, int rank, const ExecutionPlan& plan) {
357+
std::vector<NvlsInfo> nvlsInfos = plan.impl_->getNvlsInfos(rank, sendBufferSize, recvBufferSize);
357358
for (size_t i = 0; i < nvlsInfos.size(); i++) {
358359
std::shared_ptr<NvlsConnection> nvlsConnection = context.nvlsConnections[i];
359360
NvlsInfo info = nvlsInfos[i];

src/include/execution_plan.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ struct ExecutionPlan::Impl {
6969
std::vector<ChannelInfo> getChannelInfos(int rank, BufferType bufferType) const;
7070
std::vector<ChannelInfo> getChannelInfosByDstRank(int rank, BufferType bufferType) const;
7171
std::vector<ChannelInfo> getUnpairedChannelInfos(int rank, int worldSize, ChannelType channelType);
72-
std::vector<NvlsInfo> getNvlsInfos(int rank) const;
72+
std::vector<NvlsInfo> getNvlsInfos(int rank, size_t sendBuffserSize = 0, size_t recvBufferSize = 0) const;
7373
std::vector<int> getConnectedPeers(int rank) const;
7474
std::vector<BufferType> getConnectedBufferTypes(int rank) const;
7575
size_t getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const;

0 commit comments

Comments
 (0)