Skip to content

Commit

Permalink
Lint + MPIComm on testing architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Dec 22, 2024
1 parent 07cd0f3 commit 224e6e2
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 27 deletions.
16 changes: 6 additions & 10 deletions ndsl/comm/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,14 @@ def all_reduce_per_element(
self.comm.Allreduce(input_quantity.data, output_quantity.data, op)

def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs):
with (
send_buffer(numpy_module.zeros, sendbuf) as send,
recv_buffer(numpy_module.zeros, recvbuf) as recv,
):
self.comm.Scatter(send, recv, **kwargs)
with send_buffer(numpy_module.zeros, sendbuf) as send:
with recv_buffer(numpy_module.zeros, recvbuf) as recv:
self.comm.Scatter(send, recv, **kwargs)

def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs):
with (
send_buffer(numpy_module.zeros, sendbuf) as send,
recv_buffer(numpy_module.zeros, recvbuf) as recv,
):
self.comm.Gather(send, recv, **kwargs)
with send_buffer(numpy_module.zeros, sendbuf) as send:
with recv_buffer(numpy_module.zeros, recvbuf) as recv:
self.comm.Gather(send, recv, **kwargs)

def scatter(
self,
Expand Down
11 changes: 5 additions & 6 deletions ndsl/stencils/testing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
CubedSphereCommunicator,
TileCommunicator,
)
from ndsl.comm.mpi import MPI
from ndsl.comm.mpi import MPI, MPIComm
from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner
from ndsl.dsl.dace.dace_config import DaceConfig
from ndsl.namelist import Namelist
Expand Down Expand Up @@ -308,7 +308,7 @@ def compute_grid_data(grid, namelist, backend, layout, topology_mode):
npx=namelist.npx,
npy=namelist.npy,
npz=namelist.npz,
communicator=get_communicator(MPI.COMM_WORLD, layout, topology_mode),
communicator=get_communicator(MPIComm(), layout, topology_mode),
backend=backend,
)

Expand Down Expand Up @@ -360,13 +360,12 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str):
metafunc.config
)
# get MPI environment
comm = MPI.COMM_WORLD
mpi_rank = comm.Get_rank()
comm = MPIComm()
savepoint_cases = parallel_savepoint_cases(
metafunc,
data_path,
namelist_filename,
mpi_rank,
comm.Get_rank(),
backend=backend,
comm=comm,
)
Expand All @@ -376,7 +375,7 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str):


def get_communicator(comm, layout, topology_mode):
if (MPI.COMM_WORLD.Get_size() > 1) and (topology_mode == "cubed-sphere"):
if (comm.Get_size() > 1) and (topology_mode == "cubed-sphere"):
partitioner = CubedSpherePartitioner(TilePartitioner(layout))
communicator = CubedSphereCommunicator(comm, partitioner)
else:
Expand Down
17 changes: 9 additions & 8 deletions ndsl/stencils/testing/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import ndsl.dsl.gt4py_utils as gt_utils
from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator
from ndsl.comm.mpi import MPI
from ndsl.comm.mpi import MPI, MPIComm
from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner
from ndsl.dsl.dace.dace_config import DaceConfig
from ndsl.dsl.stencil import CompilationConfig, StencilConfig
Expand Down Expand Up @@ -288,18 +288,19 @@ def test_parallel_savepoint(
multimodal_metric,
xy_indices=True,
):
if MPI.COMM_WORLD.Get_size() % 6 != 0:
mpi_comm = MPIComm()
if mpi_comm.Get_size() % 6 != 0:
layout = (
int(MPI.COMM_WORLD.Get_size() ** 0.5),
int(MPI.COMM_WORLD.Get_size() ** 0.5),
int(mpi_comm.Get_size() ** 0.5),
int(mpi_comm.Get_size() ** 0.5),
)
communicator = get_tile_communicator(MPI.COMM_WORLD, layout)
communicator = get_tile_communicator(mpi_comm, layout)
else:
layout = (
int((MPI.COMM_WORLD.Get_size() // 6) ** 0.5),
int((MPI.COMM_WORLD.Get_size() // 6) ** 0.5),
int((mpi_comm.Get_size() // 6) ** 0.5),
int((mpi_comm.Get_size() // 6) ** 0.5),
)
communicator = get_communicator(MPI.COMM_WORLD, layout)
communicator = get_communicator(mpi_comm, layout)
if case.testobj is None:
pytest.xfail(
f"no translate object available for savepoint {case.savepoint_name}"
Expand Down
2 changes: 1 addition & 1 deletion tests/dsl/test_compilation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
CompilationConfig,
CubedSphereCommunicator,
CubedSpherePartitioner,
NullComm,
RunMode,
TilePartitioner,
NullComm,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/mpi/test_mpi_halo_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
Quantity,
TilePartitioner,
)
from ndsl.comm.mpi import MPIComm
from ndsl.comm._boundary_utils import get_boundary_slice
from ndsl.comm.mpi import MPIComm
from ndsl.constants import (
BOUNDARY_TYPES,
EDGE_BOUNDARY_TYPES,
Expand Down Expand Up @@ -40,7 +40,7 @@ def layout():
if MPI is not None:
size = MPI.COMM_WORLD.Get_size()
ranks_per_tile = size // 6
ranks_per_edge = int(ranks_per_tile**0.5)
ranks_per_edge = int(ranks_per_tile ** 0.5)
return (ranks_per_edge, ranks_per_edge)
else:
return (1, 1)
Expand Down

0 comments on commit 224e6e2

Please sign in to comment.