Skip to content

Commit 5777b98

Browse files
Merge pull request #1139 from arcaneframework/dev/gg-add-helper-function-for-broadcast
Mutualize arrays broadcast between ranks in 'MshParallelMeshReader'
2 parents 1fc97af + 92b53b3 commit 5777b98

File tree

1 file changed

+46
-70
lines changed

1 file changed

+46
-70
lines changed

Diff for: arcane/src/arcane/std/MshParallelMeshReader.cc

+46-70
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,43 @@ _readElementsFromFileAscii()
823823
/*---------------------------------------------------------------------------*/
824824
/*---------------------------------------------------------------------------*/
825825

826+
namespace
827+
{
828+
template <typename DataType> inline ArrayView<DataType>
829+
_broadcastArrayWithSize(IParallelMng* pm, UniqueArray<DataType>& values,
830+
UniqueArray<DataType>& work_values, Int32 dest_rank, Int64 size)
831+
{
832+
const Int32 my_rank = pm->commRank();
833+
ArrayView<DataType> view = values.view();
834+
if (my_rank != dest_rank) {
835+
work_values.resize(size);
836+
view = work_values.view();
837+
}
838+
pm->broadcast(view, dest_rank);
839+
return view;
840+
}
841+
/*!
842+
* \brief Broadcast un tableau et retourne une vue dessus.
843+
*
844+
* Si on est le rang \a dest_rank, alors on broadcast \a values.
845+
* Les autres rangs récupèrent la valeur dans \a work_values et
846+
* retournent une vue dessus.
847+
*/
848+
template <typename DataType> inline ArrayView<DataType>
849+
_broadcastArray(IParallelMng* pm, UniqueArray<DataType>& values,
850+
UniqueArray<DataType>& work_values, Int32 dest_rank)
851+
{
852+
const Int32 my_rank = pm->commRank();
853+
Int64 size = 0;
854+
// Envoie la taille
855+
if (dest_rank == my_rank)
856+
size = values.size();
857+
pm->broadcast(ArrayView<Int64>(1, &size), dest_rank);
858+
return _broadcastArrayWithSize(pm, values, work_values, dest_rank, size);
859+
}
860+
861+
} // namespace
862+
826863
void MshParallelMeshReader::
827864
_computeOwnCells(MeshV4ElementsBlock& block)
828865
{
@@ -841,28 +878,10 @@ _computeOwnCells(MeshV4ElementsBlock& block)
841878
const Int32 nb_part = m_parts_rank.size();
842879
for (Int32 i_part = 0; i_part < nb_part; ++i_part) {
843880
const Int32 dest_rank = m_parts_rank[i_part];
844-
ArrayView<Int64> connectivities_view;
845-
ArrayView<Int64> uids_view;
846-
847881
// Broadcast la i_part-ème partie des uids et connectivités des mailles
848-
{
849-
FixedArray<Int64, 2> nb_connectivity_and_uid;
850-
if (my_rank == dest_rank) {
851-
nb_connectivity_and_uid[0] = block.connectivities.size();
852-
nb_connectivity_and_uid[1] = block.uids.size();
853-
connectivities_view = block.connectivities;
854-
uids_view = block.uids;
855-
}
856-
pm->broadcast(nb_connectivity_and_uid.view(), dest_rank);
857-
if (my_rank != dest_rank) {
858-
connectivities.resize(nb_connectivity_and_uid[0]);
859-
uids.resize(nb_connectivity_and_uid[1]);
860-
connectivities_view = connectivities;
861-
uids_view = uids;
862-
}
863-
pm->broadcast(connectivities_view, dest_rank);
864-
pm->broadcast(uids_view, dest_rank);
865-
}
882+
ArrayView<Int64> connectivities_view = _broadcastArray(pm, block.connectivities, connectivities, dest_rank);
883+
ArrayView<Int64> uids_view = _broadcastArray(pm, block.uids, uids, dest_rank);
884+
866885
Int32 nb_item = uids_view.size();
867886
nodes_rank.resize(nb_item);
868887
nodes_rank.fill(-1);
@@ -920,34 +939,15 @@ _setNodesCoordinates()
920939
UniqueArray<Int32> local_ids;
921940

922941
IParallelMng* pm = m_parallel_mng;
923-
const Int32 my_rank = pm->commRank();
924942

925943
const IItemFamily* node_family = m_mesh->nodeFamily();
926944
VariableNodeReal3& nodes_coord_var(m_mesh->nodesCoordinates());
927945

928946
for (Int32 dest_rank : m_parts_rank) {
929-
Int32 nb_item = 0;
930-
if (my_rank == dest_rank) {
931-
nb_item = m_mesh_info.nodes_unique_id.size();
932-
}
933-
// Envoie le nombre de noeuds aux autres
934-
pm->broadcast(ArrayView<Int32>(1, &nb_item), dest_rank);
935-
ConstArrayView<Int64> uids;
936-
ConstArrayView<Real3> coords;
937-
if (my_rank == dest_rank) {
938-
pm->broadcast(m_mesh_info.nodes_unique_id, dest_rank);
939-
pm->broadcast(m_mesh_info.nodes_coordinates, dest_rank);
940-
uids = m_mesh_info.nodes_unique_id.view();
941-
coords = m_mesh_info.nodes_coordinates.view();
942-
}
943-
else {
944-
uids_storage.resize(nb_item);
945-
coords_storage.resize(nb_item);
946-
pm->broadcast(uids_storage, dest_rank);
947-
pm->broadcast(coords_storage, dest_rank);
948-
uids = uids_storage.view();
949-
coords = coords_storage.view();
950-
}
947+
ConstArrayView<Int64> uids = _broadcastArray(pm, m_mesh_info.nodes_unique_id, uids_storage, dest_rank);
948+
ConstArrayView<Real3> coords = _broadcastArray(pm, m_mesh_info.nodes_coordinates, coords_storage, dest_rank);
949+
950+
Int32 nb_item = uids.size();
951951
local_ids.resize(nb_item);
952952

953953
// Converti les uniqueId() en localId(). S'ils sont non nuls
@@ -1074,23 +1074,11 @@ void MshParallelMeshReader::
10741074
_addFaceGroup(MeshV4ElementsBlock& block, const String& group_name)
10751075
{
10761076
IParallelMng* pm = m_parallel_mng;
1077-
const Int32 my_rank = pm->commRank();
10781077
const Int32 item_nb_node = block.item_nb_node;
10791078

10801079
UniqueArray<Int64> connectivities;
10811080
for (Int32 dest_rank : m_parts_rank) {
1082-
Int64 nb_connectivity = 0;
1083-
ArrayView<Int64> connectivities_view;
1084-
if (my_rank == dest_rank) {
1085-
nb_connectivity = block.connectivities.size();
1086-
connectivities_view = block.connectivities;
1087-
}
1088-
pm->broadcast(ArrayView<Int64>(1, &nb_connectivity), dest_rank);
1089-
if (my_rank != dest_rank) {
1090-
connectivities.resize(nb_connectivity);
1091-
connectivities_view = connectivities;
1092-
}
1093-
pm->broadcast(connectivities_view, dest_rank);
1081+
ArrayView<Int64> connectivities_view = _broadcastArray(pm, block.connectivities, connectivities, dest_rank);
10941082
_addFaceGroupOnePart(connectivities_view, item_nb_node, group_name, block.index);
10951083
}
10961084
}
@@ -1181,22 +1169,10 @@ void MshParallelMeshReader::
11811169
_addCellOrNodeGroup(MeshV4ElementsBlock& block, const String& group_name, IItemFamily* family)
11821170
{
11831171
IParallelMng* pm = m_parallel_mng;
1184-
const Int32 my_rank = pm->commRank();
11851172

11861173
UniqueArray<Int64> uids;
11871174
for (Int32 dest_rank : m_parts_rank) {
1188-
Int64 nb_uid = 0;
1189-
ArrayView<Int64> uids_view;
1190-
if (my_rank == dest_rank) {
1191-
nb_uid = block.uids.size();
1192-
uids_view = block.uids;
1193-
}
1194-
pm->broadcast(ArrayView<Int64>(1, &nb_uid), dest_rank);
1195-
if (my_rank != dest_rank) {
1196-
uids.resize(nb_uid);
1197-
uids_view = uids;
1198-
}
1199-
pm->broadcast(uids_view, dest_rank);
1175+
ArrayView<Int64> uids_view = _broadcastArray(pm, block.uids, uids, dest_rank);
12001176
_addCellOrNodeGroupOnePart(uids_view, group_name, block.index, family);
12011177
}
12021178
}

0 commit comments

Comments
 (0)