From 0e8089eed9c909bf91cc1a117ecf12cf6cfe7397 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 10:45:06 -0500 Subject: [PATCH] Update utest --- tests/mpi/test_mpi_all_reduce_sum.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 858a7f94..4a15ad53 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -7,6 +7,7 @@ Quantity, TilePartitioner, ) +from ndsl.comm.comm_abc import ReductionOperator from ndsl.dsl.typing import Float from tests.mpi.mpi_comm import MPI @@ -48,10 +49,7 @@ def communicator(cube_partitioner): @pytest.mark.skipif( MPI is None, reason="mpi4py is not available or pytest was not run in parallel" ) -def test_all_reduce_sum( - communicator, -): - +def test_all_reduce(communicator): backends = ["dace:cpu", "gt:cpu_kfirst", "numpy"] for backend in backends: @@ -84,15 +82,15 @@ def test_all_reduce_sum( gt4py_backend=backend, ) - global_sum_q = communicator.all_reduce_sum(testQuantity_1D) + global_sum_q = communicator.all_reduce(testQuantity_1D, ReductionOperator.SUM) assert global_sum_q.metadata == testQuantity_1D.metadata assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all() - global_sum_q = communicator.all_reduce_sum(testQuantity_2D) + global_sum_q = communicator.all_reduce(testQuantity_2D, ReductionOperator.SUM) assert global_sum_q.metadata == testQuantity_2D.metadata assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all() - global_sum_q = communicator.all_reduce_sum(testQuantity_3D) + global_sum_q = communicator.all_reduce(testQuantity_3D, ReductionOperator.SUM) assert global_sum_q.metadata == testQuantity_3D.metadata assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() @@ -125,19 +123,25 @@ def test_all_reduce_sum( units="Some 3D unit", gt4py_backend=backend, ) - communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out) + communicator.all_reduce( + testQuantity_1D, ReductionOperator.SUM, testQuantity_1D_out + ) assert testQuantity_1D_out.metadata == testQuantity_1D.metadata assert ( testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size) ).all() - communicator.all_reduce_sum(testQuantity_2D, testQuantity_2D_out) + communicator.all_reduce( + testQuantity_2D, ReductionOperator.SUM, testQuantity_2D_out + ) assert testQuantity_2D_out.metadata == testQuantity_2D.metadata assert ( testQuantity_2D_out.data == (testQuantity_2D.data * communicator.size) ).all() - communicator.all_reduce_sum(testQuantity_3D, testQuantity_3D_out) + communicator.all_reduce( + testQuantity_3D, ReductionOperator.SUM, testQuantity_3D_out + ) assert testQuantity_3D_out.metadata == testQuantity_3D.metadata assert ( testQuantity_3D_out.data == (testQuantity_3D.data * communicator.size)