Skip to content

Commit

Permalink
Update utest
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Dec 22, 2024
1 parent cc620c6 commit 0e8089e
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions tests/mpi/test_mpi_all_reduce_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Quantity,
TilePartitioner,
)
from ndsl.comm.comm_abc import ReductionOperator
from ndsl.dsl.typing import Float
from tests.mpi.mpi_comm import MPI

Expand Down Expand Up @@ -48,10 +49,7 @@ def communicator(cube_partitioner):
@pytest.mark.skipif(
MPI is None, reason="mpi4py is not available or pytest was not run in parallel"
)
def test_all_reduce_sum(
communicator,
):

def test_all_reduce(communicator):
backends = ["dace:cpu", "gt:cpu_kfirst", "numpy"]

for backend in backends:
Expand Down Expand Up @@ -84,15 +82,15 @@ def test_all_reduce_sum(
gt4py_backend=backend,
)

global_sum_q = communicator.all_reduce_sum(testQuantity_1D)
global_sum_q = communicator.all_reduce(testQuantity_1D, ReductionOperator.SUM)
assert global_sum_q.metadata == testQuantity_1D.metadata
assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all()

global_sum_q = communicator.all_reduce_sum(testQuantity_2D)
global_sum_q = communicator.all_reduce(testQuantity_2D, ReductionOperator.SUM)
assert global_sum_q.metadata == testQuantity_2D.metadata
assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all()

global_sum_q = communicator.all_reduce_sum(testQuantity_3D)
global_sum_q = communicator.all_reduce(testQuantity_3D, ReductionOperator.SUM)
assert global_sum_q.metadata == testQuantity_3D.metadata
assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all()

Expand Down Expand Up @@ -125,19 +123,25 @@ def test_all_reduce_sum(
units="Some 3D unit",
gt4py_backend=backend,
)
communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out)
communicator.all_reduce(
testQuantity_1D, ReductionOperator.SUM, testQuantity_1D_out
)
assert testQuantity_1D_out.metadata == testQuantity_1D.metadata
assert (
testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size)
).all()

communicator.all_reduce_sum(testQuantity_2D, testQuantity_2D_out)
communicator.all_reduce(
testQuantity_2D, ReductionOperator.SUM, testQuantity_2D_out
)
assert testQuantity_2D_out.metadata == testQuantity_2D.metadata
assert (
testQuantity_2D_out.data == (testQuantity_2D.data * communicator.size)
).all()

communicator.all_reduce_sum(testQuantity_3D, testQuantity_3D_out)
communicator.all_reduce(
testQuantity_3D, ReductionOperator.SUM, testQuantity_3D_out
)
assert testQuantity_3D_out.metadata == testQuantity_3D.metadata
assert (
testQuantity_3D_out.data == (testQuantity_3D.data * communicator.size)
Expand Down

0 comments on commit 0e8089e

Please sign in to comment.