From 224e6e24ecfc09e65bf8f0ec1f9b3f3b0d4c7ed2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 13:50:37 -0500 Subject: [PATCH] Lint + `MPIComm` on testing architecture --- ndsl/comm/communicator.py | 16 ++++++---------- ndsl/stencils/testing/conftest.py | 11 +++++------ ndsl/stencils/testing/test_translate.py | 17 +++++++++-------- tests/dsl/test_compilation_config.py | 2 +- tests/mpi/test_mpi_halo_update.py | 4 ++-- 5 files changed, 23 insertions(+), 27 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index c952c022..1ea4f5a3 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -143,18 +143,14 @@ def all_reduce_per_element( self.comm.Allreduce(input_quantity.data, output_quantity.data, op) def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): - with ( - send_buffer(numpy_module.zeros, sendbuf) as send, - recv_buffer(numpy_module.zeros, recvbuf) as recv, - ): - self.comm.Scatter(send, recv, **kwargs) + with send_buffer(numpy_module.zeros, sendbuf) as send: + with recv_buffer(numpy_module.zeros, recvbuf) as recv: + self.comm.Scatter(send, recv, **kwargs) def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): - with ( - send_buffer(numpy_module.zeros, sendbuf) as send, - recv_buffer(numpy_module.zeros, recvbuf) as recv, - ): - self.comm.Gather(send, recv, **kwargs) + with send_buffer(numpy_module.zeros, sendbuf) as send: + with recv_buffer(numpy_module.zeros, recvbuf) as recv: + self.comm.Gather(send, recv, **kwargs) def scatter( self, diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index af5bb6a6..9810fb4a 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -13,7 +13,7 @@ CubedSphereCommunicator, TileCommunicator, ) -from ndsl.comm.mpi import MPI +from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.namelist import Namelist @@ -308,7 +308,7 @@ def compute_grid_data(grid, namelist, backend, layout, topology_mode): npx=namelist.npx, npy=namelist.npy, npz=namelist.npz, - communicator=get_communicator(MPI.COMM_WORLD, layout, topology_mode), + communicator=get_communicator(MPIComm(), layout, topology_mode), backend=backend, ) @@ -360,13 +360,12 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str): metafunc.config ) # get MPI environment - comm = MPI.COMM_WORLD - mpi_rank = comm.Get_rank() + comm = MPIComm() savepoint_cases = parallel_savepoint_cases( metafunc, data_path, namelist_filename, - mpi_rank, + comm.Get_rank(), backend=backend, comm=comm, ) @@ -376,7 +375,7 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str): def get_communicator(comm, layout, topology_mode): - if (MPI.COMM_WORLD.Get_size() > 1) and (topology_mode == "cubed-sphere"): + if (comm.Get_size() > 1) and (topology_mode == "cubed-sphere"): partitioner = CubedSpherePartitioner(TilePartitioner(layout)) communicator = CubedSphereCommunicator(comm, partitioner) else: diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index db8e6047..64ae5f62 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -8,7 +8,7 @@ import ndsl.dsl.gt4py_utils as gt_utils from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator -from ndsl.comm.mpi import MPI +from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.dsl.stencil import CompilationConfig, StencilConfig @@ -288,18 +288,19 @@ def test_parallel_savepoint( multimodal_metric, xy_indices=True, ): - if MPI.COMM_WORLD.Get_size() % 6 != 0: + mpi_comm = MPIComm() + if mpi_comm.Get_size() % 6 != 0: layout = ( - int(MPI.COMM_WORLD.Get_size() ** 0.5), - int(MPI.COMM_WORLD.Get_size() ** 0.5), + int(mpi_comm.Get_size() ** 0.5), + int(mpi_comm.Get_size() ** 0.5), ) - communicator = get_tile_communicator(MPI.COMM_WORLD, layout) + communicator = get_tile_communicator(mpi_comm, layout) else: layout = ( - int((MPI.COMM_WORLD.Get_size() // 6) ** 0.5), - int((MPI.COMM_WORLD.Get_size() // 6) ** 0.5), + int((mpi_comm.Get_size() // 6) ** 0.5), + int((mpi_comm.Get_size() // 6) ** 0.5), ) - communicator = get_communicator(MPI.COMM_WORLD, layout) + communicator = get_communicator(mpi_comm, layout) if case.testobj is None: pytest.xfail( f"no translate object available for savepoint {case.savepoint_name}" diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 95ca7f74..fa323b06 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -7,9 +7,9 @@ CompilationConfig, CubedSphereCommunicator, CubedSpherePartitioner, + NullComm, RunMode, TilePartitioner, - NullComm, ) diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 1e6aaefc..b6c38e95 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -8,8 +8,8 @@ Quantity, TilePartitioner, ) -from ndsl.comm.mpi import MPIComm from ndsl.comm._boundary_utils import get_boundary_slice +from ndsl.comm.mpi import MPIComm from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -40,7 +40,7 @@ def layout(): if MPI is not None: size = MPI.COMM_WORLD.Get_size() ranks_per_tile = size // 6 - ranks_per_edge = int(ranks_per_tile**0.5) + ranks_per_edge = int(ranks_per_tile ** 0.5) return (ranks_per_edge, ranks_per_edge) else: return (1, 1)