Skip to content

Commit 442a119

Browse files
committed
add tests for [multi]mapped_over_array_containers
1 parent b395fa8 commit 442a119

File tree

1 file changed

+86
-8
lines changed

1 file changed

+86
-8
lines changed

test/test_arraycontext.py

+86-8
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory):
756756
assert result is not None
757757

758758

759+
def test_container_map(actx_factory):
760+
actx = actx_factory()
761+
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
762+
_get_test_containers(actx)
763+
764+
# {{{ check
765+
766+
def _check_allclose(f, arg1, arg2, atol=2.0e-14):
767+
from arraycontext import NotAnArrayContainerError
768+
try:
769+
arg1_iterable = serialize_container(arg1)
770+
arg2_iterable = serialize_container(arg2)
771+
except NotAnArrayContainerError:
772+
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
773+
else:
774+
arg1_subarrays = [
775+
subarray for _, subarray in arg1_iterable]
776+
arg2_subarrays = [
777+
subarray for _, subarray in arg2_iterable]
778+
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
779+
_check_allclose(f, subarray1, subarray2)
780+
781+
def func(x):
782+
return x + 1
783+
784+
from arraycontext import rec_map_array_container
785+
result = rec_map_array_container(func, 1)
786+
assert result == 2
787+
788+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
789+
result = rec_map_array_container(func, ary)
790+
_check_allclose(func, ary, result)
791+
792+
from arraycontext import mapped_over_array_containers
793+
794+
@mapped_over_array_containers
795+
def mapped_func(x):
796+
return func(x)
797+
798+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
799+
result = mapped_func(ary)
800+
_check_allclose(func, ary, result)
801+
802+
@mapped_over_array_containers(leaf_class=DOFArray)
803+
def check_leaf(x):
804+
assert isinstance(x, DOFArray)
805+
806+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
807+
check_leaf(ary)
808+
809+
# }}}
810+
811+
759812
def test_container_multimap(actx_factory):
760813
actx = actx_factory()
761814
ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs, bcast_dc_of_dofs = \
@@ -764,7 +817,19 @@ def test_container_multimap(actx_factory):
764817
# {{{ check
765818

766819
def _check_allclose(f, arg1, arg2, atol=2.0e-14):
767-
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
820+
from arraycontext import NotAnArrayContainerError
821+
try:
822+
arg1_iterable = serialize_container(arg1)
823+
arg2_iterable = serialize_container(arg2)
824+
except NotAnArrayContainerError:
825+
assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol
826+
else:
827+
arg1_subarrays = [
828+
subarray for _, subarray in arg1_iterable]
829+
arg2_subarrays = [
830+
subarray for _, subarray in arg2_iterable]
831+
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays):
832+
_check_allclose(f, subarray1, subarray2)
768833

769834
def func_all_scalar(x, y):
770835
return x + y
@@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2):
779844
result = rec_multimap_array_container(func_all_scalar, 1, 2)
780845
assert result == 3
781846

782-
from functools import partial
783847
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
784848
result = rec_multimap_array_container(func_first_scalar, 1, ary)
785-
rec_multimap_array_container(
786-
partial(_check_allclose, lambda x: 1 + x),
787-
ary, result)
849+
_check_allclose(lambda x: 1 + x, ary, result)
788850

789851
result = rec_multimap_array_container(func_multiple_scalar, 2, ary, 2, ary)
790-
rec_multimap_array_container(
791-
partial(_check_allclose, lambda x: 4 * x),
792-
ary, result)
852+
_check_allclose(lambda x: 4 * x, ary, result)
853+
854+
from arraycontext import multimapped_over_array_containers
855+
856+
@multimapped_over_array_containers
857+
def mapped_func(a, subary1, b, subary2):
858+
return func_multiple_scalar(a, subary1, b, subary2)
859+
860+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
861+
result = mapped_func(2, ary, 2, ary)
862+
_check_allclose(lambda x: 4 * x, ary, result)
863+
864+
@multimapped_over_array_containers(leaf_class=DOFArray)
865+
def check_leaf(a, subary1, b, subary2):
866+
assert isinstance(subary1, DOFArray)
867+
assert isinstance(subary2, DOFArray)
868+
869+
for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]:
870+
check_leaf(2, ary, 2, ary)
793871

794872
with pytest.raises(AssertionError):
795873
rec_multimap_array_container(func_multiple_scalar, 2, ary_dof, 2, dc_of_dofs)

0 commit comments

Comments
 (0)