// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/distributed/neighborhood_communicator.hpp"

#include <algorithm>
#include <numeric>

#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/matrix/dense.hpp>


namespace gko {
namespace experimental {
namespace mpi {


/**
 * \brief Computes the inverse envelope (target-ids, sizes) for a given
 *        one-sided communication pattern.
 *
 * \param exec the executor; the communication is always performed on its
 *             master (host) executor
 * \param comm the communicator
 * \param ids the target ids of the one-sided operation
 * \param sizes the number of elements sent to each id
 *
 * \return the inverse envelope consisting of the target-ids and the sizes
 */
std::tuple<std::vector<comm_index_type>, std::vector<comm_index_type>>
communicate_inverse_envelope(std::shared_ptr<const Executor> exec,
                             mpi::communicator comm,
                             const std::vector<comm_index_type>& ids,
                             const std::vector<comm_index_type>& sizes)
{
    auto host_exec = exec->get_master();
    // each rank exposes a dense array of size comm.size(); every sending rank
    // writes its send size into its own slot of each target's window via
    // one-sided puts
    std::vector<comm_index_type> inverse_sizes_full(comm.size());
    mpi::window<comm_index_type> window(host_exec, inverse_sizes_full.data(),
                                        inverse_sizes_full.size(), comm,
                                        sizeof(comm_index_type), MPI_INFO_ENV);
    window.fence();
    for (int i = 0; i < ids.size(); ++i) {
        window.put(host_exec, sizes.data() + i, 1, ids[i], comm.rank(), 1);
    }
    window.fence();

    // compact the dense array into (target-id, size) pairs, dropping ranks
    // that don't send anything to this rank
    std::vector<comm_index_type> inverse_sizes;
    std::vector<comm_index_type> inverse_ids;
    for (int i = 0; i < inverse_sizes_full.size(); ++i) {
        if (inverse_sizes_full[i] > 0) {
            inverse_ids.push_back(i);
            inverse_sizes.push_back(inverse_sizes_full[i]);
        }
    }

    return std::make_tuple(std::move(inverse_ids), std::move(inverse_sizes));
}
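
// A minimal usage sketch (hypothetical caller code, not part of the library
// sources): assuming this rank sends sizes[i] elements to rank ids[i],
//
//     auto envelope = communicate_inverse_envelope(exec, comm, ids, sizes);
//     auto& recv_from_ids = std::get<0>(envelope);
//     auto& recv_sizes = std::get<1>(envelope);
//
// yields the ranks this rank will receive from and the corresponding element
// counts. For example, on three ranks where rank 0 sends 2 elements to rank 1
// and rank 2 sends 3 elements to rank 1, rank 1 obtains
// recv_from_ids = {0, 2} and recv_sizes = {2, 3}.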


/**
 * Creates a distributed graph communicator based on the input sources and
 * destinations.
 *
 * The graph is unweighted and has the same rank ordering as the input
 * communicator.
 */
mpi::communicator create_neighborhood_comm(
    mpi::communicator base, const std::vector<comm_index_type>& sources,
    const std::vector<comm_index_type>& destinations)
{
    auto in_degree = static_cast<comm_index_type>(sources.size());
    auto out_degree = static_cast<comm_index_type>(destinations.size());

    // adjacent constructor guarantees that querying sources/destinations
    // will result in the array having the same order as defined here
    MPI_Comm graph_comm;
    MPI_Info info;
    GKO_ASSERT_NO_MPI_ERRORS(MPI_Info_dup(MPI_INFO_ENV, &info));
    GKO_ASSERT_NO_MPI_ERRORS(MPI_Dist_graph_create_adjacent(
        base.get(), in_degree, sources.data(),
        in_degree ? MPI_UNWEIGHTED : MPI_WEIGHTS_EMPTY, out_degree,
        destinations.data(), out_degree ? MPI_UNWEIGHTED : MPI_WEIGHTS_EMPTY,
        info, false, &graph_comm));
    GKO_ASSERT_NO_MPI_ERRORS(MPI_Info_free(&info));

    return mpi::communicator::create_owning(graph_comm,
                                            base.force_host_buffer());
}
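
// Illustrative sketch (hypothetical caller code, not part of the library
// sources): a rank that receives from ranks {1, 2} and sends to rank 2 would
// create its graph communicator as
//
//     std::vector<comm_index_type> sources{1, 2};
//     std::vector<comm_index_type> destinations{2};
//     auto graph_comm = create_neighborhood_comm(base, sources, destinations);
//
// Neighbor collectives on such a communicator order their receive buffers by
// `sources` and their send buffers by `destinations`, which is why the
// ordering guarantee of the adjacent constructor above matters.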


std::unique_ptr<CollectiveCommunicator>
NeighborhoodCommunicator::create_inverse() const
{
    auto base_comm = this->get_base_communicator();
    distributed::comm_index_type num_sources;
    distributed::comm_index_type num_destinations;
    distributed::comm_index_type weighted;
    GKO_ASSERT_NO_MPI_ERRORS(MPI_Dist_graph_neighbors_count(
        comm_.get(), &num_sources, &num_destinations, &weighted));

    std::vector<distributed::comm_index_type> sources(num_sources);
    std::vector<distributed::comm_index_type> destinations(num_destinations);
    GKO_ASSERT_NO_MPI_ERRORS(MPI_Dist_graph_neighbors(
        comm_.get(), num_sources, sources.data(), MPI_UNWEIGHTED,
        num_destinations, destinations.data(), MPI_UNWEIGHTED));

    // the inverse swaps the roles of sources/destinations as well as the
    // send/receive sizes and offsets
    return std::make_unique<NeighborhoodCommunicator>(
        base_comm, destinations, send_sizes_, send_offsets_, sources,
        recv_sizes_, recv_offsets_);
}
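
// A minimal sketch of the intended use (hypothetical, not part of the library
// sources): the inverse communicator reverses the communication direction, so
// data received through `comm->i_all_to_all_v(...)` can be sent back along the
// same pattern via `comm->create_inverse()->i_all_to_all_v(...)` with the send
// and receive buffers swapped.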


comm_index_type NeighborhoodCommunicator::get_recv_size() const
{
    return recv_offsets_.back();
}


comm_index_type NeighborhoodCommunicator::get_send_size() const
{
    return send_offsets_.back();
}


NeighborhoodCommunicator::NeighborhoodCommunicator(
    communicator base, const std::vector<distributed::comm_index_type>& sources,
    const std::vector<comm_index_type>& recv_sizes,
    const std::vector<comm_index_type>& recv_offsets,
    const std::vector<distributed::comm_index_type>& destinations,
    const std::vector<comm_index_type>& send_sizes,
    const std::vector<comm_index_type>& send_offsets)
    : CollectiveCommunicator(base), comm_(MPI_COMM_NULL)
{
    comm_ = create_neighborhood_comm(base, sources, destinations);
    send_sizes_ = send_sizes;
    send_offsets_ = send_offsets;
    recv_sizes_ = recv_sizes;
    recv_offsets_ = recv_offsets;
}


NeighborhoodCommunicator::NeighborhoodCommunicator(communicator base)
    : CollectiveCommunicator(std::move(base)),
      comm_(MPI_COMM_SELF),
      send_sizes_(),
      send_offsets_(1),
      recv_sizes_(),
      recv_offsets_(1)
{
    // ensure that comm_ always has the correct topology
    // the vector is created with size 1 and shrunk to 0 so that its data()
    // pointer is non-null even though the vector is empty
    std::vector<comm_index_type> non_nullptr(1);
    non_nullptr.resize(0);
    comm_ = create_neighborhood_comm(this->get_base_communicator(),
                                     non_nullptr, non_nullptr);
}


request NeighborhoodCommunicator::i_all_to_all_v(
    std::shared_ptr<const Executor> exec, const void* send_buffer,
    MPI_Datatype send_type, void* recv_buffer, MPI_Datatype recv_type) const
{
    auto guard = exec->get_scoped_device_id_guard();
    request req;
    GKO_ASSERT_NO_MPI_ERRORS(MPI_Ineighbor_alltoallv(
        send_buffer, send_sizes_.data(), send_offsets_.data(), send_type,
        recv_buffer, recv_sizes_.data(), recv_offsets_.data(), recv_type,
        comm_.get(), req.get()));
    return req;
}
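
// Illustrative usage sketch (hypothetical caller code `coll_comm` and `exec`,
// not part of the library sources): the send buffer has to hold
// get_send_size() elements grouped by destination according to send_offsets_,
// and the receive buffer get_recv_size() elements grouped by source according
// to recv_offsets_:
//
//     std::vector<double> send_buf(coll_comm.get_send_size());
//     std::vector<double> recv_buf(coll_comm.get_recv_size());
//     // ... fill send_buf ...
//     auto req = coll_comm.i_all_to_all_v(exec, send_buf.data(), MPI_DOUBLE,
//                                         recv_buf.data(), MPI_DOUBLE);
//     req.wait();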


std::unique_ptr<CollectiveCommunicator>
NeighborhoodCommunicator::create_with_same_type(
    communicator base, const distributed::index_map_variant& imap) const
{
    return std::visit(
        [base](const auto& imap) {
            return std::unique_ptr<CollectiveCommunicator>(
                new NeighborhoodCommunicator(base, imap));
        },
        imap);
}


template <typename LocalIndexType, typename GlobalIndexType>
NeighborhoodCommunicator::NeighborhoodCommunicator(
    communicator base,
    const distributed::index_map<LocalIndexType, GlobalIndexType>& imap)
    : CollectiveCommunicator(base),
      comm_(MPI_COMM_SELF),
      recv_sizes_(imap.get_remote_target_ids().get_size()),
      recv_offsets_(recv_sizes_.size() + 1),
      send_offsets_(1)
{
    auto exec = imap.get_executor();
    if (!exec) {
        return;
    }
    auto host_exec = exec->get_master();

    // copy the remote target ids and the offsets of the remote index segments
    // to the host
    auto recv_target_ids_arr =
        make_temporary_clone(host_exec, &imap.get_remote_target_ids());
    auto remote_idx_offsets_arr = make_temporary_clone(
        host_exec, &imap.get_remote_global_idxs().get_offsets());
    std::vector<comm_index_type> recv_target_ids(
        recv_target_ids_arr->get_size());
    std::copy_n(recv_target_ids_arr->get_const_data(),
                recv_target_ids_arr->get_size(), recv_target_ids.begin());
    // the receive size per target rank is the size of the corresponding
    // segment of remote global indices
    for (size_type seg_id = 0;
         seg_id < imap.get_remote_global_idxs().get_segment_count(); ++seg_id) {
        recv_sizes_[seg_id] =
            remote_idx_offsets_arr->get_const_data()[seg_id + 1] -
            remote_idx_offsets_arr->get_const_data()[seg_id];
    }
    // the send target ranks and sizes are obtained by inverting the receive
    // pattern
    auto send_envelope =
        communicate_inverse_envelope(exec, base, recv_target_ids, recv_sizes_);
    const auto& send_target_ids = std::get<0>(send_envelope);
    send_sizes_ = std::move(std::get<1>(send_envelope));

    // turn the sizes into offsets via prefix sums
    send_offsets_.resize(send_sizes_.size() + 1);
    std::partial_sum(send_sizes_.begin(), send_sizes_.end(),
                     send_offsets_.begin() + 1);
    std::partial_sum(recv_sizes_.begin(), recv_sizes_.end(),
                     recv_offsets_.begin() + 1);

    comm_ = create_neighborhood_comm(base, recv_target_ids, send_target_ids);
}
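
// Worked example (hypothetical values, not part of the library sources): if
// the index map on this rank references remote indices owned by ranks {1, 3}
// with segment sizes {2, 5}, then recv_sizes_ = {2, 5} and
// recv_offsets_ = {0, 2, 7}. The inverse envelope communicates how many
// elements this rank has to send to each rank that references its indices,
// from which send_sizes_ and send_offsets_ are built analogously, and the
// graph communicator finally connects sources = {1, 3} with the communicated
// destinations.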


#define GKO_DECLARE_NEIGHBORHOOD_CONSTRUCTOR(LocalIndexType, GlobalIndexType) \
    NeighborhoodCommunicator::NeighborhoodCommunicator(                       \
        communicator base,                                                    \
        const distributed::index_map<LocalIndexType, GlobalIndexType>& imap)

GKO_INSTANTIATE_FOR_EACH_LOCAL_GLOBAL_INDEX_TYPE(
    GKO_DECLARE_NEIGHBORHOOD_CONSTRUCTOR);

#undef GKO_DECLARE_NEIGHBORHOOD_CONSTRUCTOR


}  // namespace mpi
}  // namespace experimental
}  // namespace gko