@@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory):
756
756
assert result is not None
757
757
758
758
759
+ def test_container_map (actx_factory ):
760
+ actx = actx_factory ()
761
+ ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs , bcast_dc_of_dofs = \
762
+ _get_test_containers (actx )
763
+
764
+ # {{{ check
765
+
766
+ def _check_allclose (f , arg1 , arg2 , atol = 2.0e-14 ):
767
+ from arraycontext import NotAnArrayContainerError
768
+ try :
769
+ arg1_iterable = serialize_container (arg1 )
770
+ arg2_iterable = serialize_container (arg2 )
771
+ except NotAnArrayContainerError :
772
+ assert np .linalg .norm (actx .to_numpy (f (arg1 ) - arg2 )) < atol
773
+ else :
774
+ arg1_subarrays = [
775
+ subarray for _ , subarray in arg1_iterable ]
776
+ arg2_subarrays = [
777
+ subarray for _ , subarray in arg2_iterable ]
778
+ for subarray1 , subarray2 in zip (arg1_subarrays , arg2_subarrays ):
779
+ _check_allclose (f , subarray1 , subarray2 )
780
+
781
+ def func (x ):
782
+ return x + 1
783
+
784
+ from arraycontext import rec_map_array_container
785
+ result = rec_map_array_container (func , 1 )
786
+ assert result == 2
787
+
788
+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
789
+ result = rec_map_array_container (func , ary )
790
+ _check_allclose (func , ary , result )
791
+
792
+ from arraycontext import mapped_over_array_containers
793
+
794
+ @mapped_over_array_containers
795
+ def mapped_func (x ):
796
+ return func (x )
797
+
798
+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
799
+ result = mapped_func (ary )
800
+ _check_allclose (func , ary , result )
801
+
802
+ @mapped_over_array_containers (leaf_class = DOFArray )
803
+ def check_leaf (x ):
804
+ assert isinstance (x , DOFArray )
805
+
806
+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
807
+ check_leaf (ary )
808
+
809
+ # }}}
810
+
811
+
759
812
def test_container_multimap (actx_factory ):
760
813
actx = actx_factory ()
761
814
ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs , bcast_dc_of_dofs = \
@@ -764,7 +817,19 @@ def test_container_multimap(actx_factory):
764
817
# {{{ check
765
818
766
819
def _check_allclose (f , arg1 , arg2 , atol = 2.0e-14 ):
767
- assert np .linalg .norm (actx .to_numpy (f (arg1 ) - arg2 )) < atol
820
+ from arraycontext import NotAnArrayContainerError
821
+ try :
822
+ arg1_iterable = serialize_container (arg1 )
823
+ arg2_iterable = serialize_container (arg2 )
824
+ except NotAnArrayContainerError :
825
+ assert np .linalg .norm (actx .to_numpy (f (arg1 ) - arg2 )) < atol
826
+ else :
827
+ arg1_subarrays = [
828
+ subarray for _ , subarray in arg1_iterable ]
829
+ arg2_subarrays = [
830
+ subarray for _ , subarray in arg2_iterable ]
831
+ for subarray1 , subarray2 in zip (arg1_subarrays , arg2_subarrays ):
832
+ _check_allclose (f , subarray1 , subarray2 )
768
833
769
834
def func_all_scalar (x , y ):
770
835
return x + y
@@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2):
779
844
result = rec_multimap_array_container (func_all_scalar , 1 , 2 )
780
845
assert result == 3
781
846
782
- from functools import partial
783
847
for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
784
848
result = rec_multimap_array_container (func_first_scalar , 1 , ary )
785
- rec_multimap_array_container (
786
- partial (_check_allclose , lambda x : 1 + x ),
787
- ary , result )
849
+ _check_allclose (lambda x : 1 + x , ary , result )
788
850
789
851
result = rec_multimap_array_container (func_multiple_scalar , 2 , ary , 2 , ary )
790
- rec_multimap_array_container (
791
- partial (_check_allclose , lambda x : 4 * x ),
792
- ary , result )
852
+ _check_allclose (lambda x : 4 * x , ary , result )
853
+
854
+ from arraycontext import multimapped_over_array_containers
855
+
856
+ @multimapped_over_array_containers
857
+ def mapped_func (a , subary1 , b , subary2 ):
858
+ return func_multiple_scalar (a , subary1 , b , subary2 )
859
+
860
+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
861
+ result = mapped_func (2 , ary , 2 , ary )
862
+ _check_allclose (lambda x : 4 * x , ary , result )
863
+
864
+ @multimapped_over_array_containers (leaf_class = DOFArray )
865
+ def check_leaf (a , subary1 , b , subary2 ):
866
+ assert isinstance (subary1 , DOFArray )
867
+ assert isinstance (subary2 , DOFArray )
868
+
869
+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
870
+ check_leaf (2 , ary , 2 , ary )
793
871
794
872
with pytest .raises (AssertionError ):
795
873
rec_multimap_array_container (func_multiple_scalar , 2 , ary_dof , 2 , dc_of_dofs )
0 commit comments