Skip to content

Commit c83395c

Browse files
authored
including sendRecv implementation and test (#27)
1 parent bb6f254 commit c83395c

File tree

4 files changed

+107
-0
lines changed

4 files changed

+107
-0
lines changed

include/faabric/scheduler/MpiWorld.h

+10
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@ class MpiWorld
8686

8787
void awaitAsyncRequest(int requestId);
8888

89+
void sendRecv(uint8_t* sendBuffer,
90+
int sendcount,
91+
faabric_datatype_t* sendDataType,
92+
int recvRank,
93+
uint8_t* recvBuffer,
94+
int recvCount,
95+
faabric_datatype_t* recvDataType,
96+
int sendRank,
97+
MPI_Status* status);
98+
8999
void scatter(int sendRank,
90100
int recvRank,
91101
const uint8_t* sendBuffer,

src/proto/faabric.proto

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ message MPIMessage {
3535
ALLREDUCE = 8;
3636
ALLTOALL = 9;
3737
RMA_WRITE = 10;
38+
SENDRECV = 11;
3839
};
3940

4041
MPIMessageType messageType = 1;

src/scheduler/MpiWorld.cpp

+37
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,43 @@ void MpiWorld::send(int sendRank,
312312
}
313313
}
314314

315+
void MpiWorld::sendRecv(uint8_t* sendBuffer,
316+
int sendCount,
317+
faabric_datatype_t* sendDataType,
318+
int recvRank,
319+
uint8_t* recvBuffer,
320+
int recvCount,
321+
faabric_datatype_t* recvDataType,
322+
int sendRank,
323+
MPI_Status* status)
324+
{
325+
auto logger = faabric::util::getLogger();
326+
logger->trace("MPI - Sendrecv");
327+
328+
if (recvRank > this->size - 1) {
329+
throw std::runtime_error(fmt::format(
330+
"Receive rank {} bigger than world size {}", recvRank, this->size));
331+
}
332+
if (sendRank > this->size - 1) {
333+
throw std::runtime_error(fmt::format(
334+
"Send rank {} bigger than world size {}", sendRank, this->size));
335+
}
336+
337+
// Post async recv
338+
int recvId = irecv(recvRank, sendRank, recvBuffer, recvDataType, recvCount);
339+
// Then send the message
340+
// TODO change MPIMessage to MPIMessage::SENDRECV. This requires a change
341+
// in the signature of doISendRecv.
342+
send(sendRank,
343+
recvRank,
344+
sendBuffer,
345+
sendDataType,
346+
sendCount,
347+
faabric::MPIMessage::NORMAL);
348+
// And wait
349+
awaitAsyncRequest(recvId);
350+
}
351+
315352
void MpiWorld::broadcast(int sendRank,
316353
const uint8_t* buffer,
317354
faabric_datatype_t* dataType,

tests/test/scheduler/test_mpi_world.cpp

+59
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,65 @@ TEST_CASE("Test send and recv on same host", "[mpi]")
229229
}
230230
}
231231

232+
// Checks that two ranks in the same world can exchange messages of different
// lengths by calling sendRecv concurrently, each receiving the other's data.
TEST_CASE("Test sendrecv", "[mpi]")
{
    cleanFaabric();

    auto msg = faabric::util::messageFactory(user, func);
    scheduler::MpiWorld world;
    world.create(msg, worldId, worldSize);

    // Register two ranks
    int rankA = 1;
    int rankB = 2;
    world.registerRank(rankA);
    world.registerRank(rankB);

    // Prepare data
    std::vector<int> messageDataAB = { 0, 1, 2 };
    std::vector<int> messageDataBA = { 3, 2, 1, 0 };

    // Each thread gets its own status and receive buffer so that nothing
    // mutable is shared between the two workers. The buffers live in the
    // test scope so they can be checked once the threads have finished.
    MPI_Status statusA{};
    MPI_Status statusB{};
    std::vector<int> recvBufferA(messageDataBA.size(), 0);
    std::vector<int> recvBufferB(messageDataAB.size(), 0);

    // sendRecv is blocking, so we run two threads.
    std::vector<std::thread> threads;

    // Run sendrecv from A
    threads.emplace_back([&] {
        world.sendRecv(BYTES(messageDataAB.data()),
                       messageDataAB.size(),
                       MPI_INT,
                       rankB,
                       BYTES(recvBufferA.data()),
                       messageDataBA.size(),
                       MPI_INT,
                       rankA,
                       &statusA);
    });

    // Run sendrecv from B
    threads.emplace_back([&] {
        world.sendRecv(BYTES(messageDataBA.data()),
                       messageDataBA.size(),
                       MPI_INT,
                       rankA,
                       BYTES(recvBufferB.data()),
                       messageDataAB.size(),
                       MPI_INT,
                       rankB,
                       &statusB);
    });

    // Wait for both to finish
    for (auto& t : threads) {
        if (t.joinable()) {
            t.join();
        }
    }

    // Test integrity of results on the main thread: a failing REQUIRE throws,
    // which would terminate the process if raised inside a worker thread, and
    // Catch2 assertions are not thread-safe by default.
    REQUIRE(recvBufferA == messageDataBA);
    REQUIRE(recvBufferB == messageDataAB);
}
290+
232291
TEST_CASE("Test async send and recv", "[mpi]")
233292
{
234293
cleanFaabric();

0 commit comments

Comments
 (0)