6767
6868import numpy as np
6969
70- from arraycontext .context import ArrayContext
70+ from arraycontext .context import ArrayContext , DeviceArray
7171from arraycontext .container import (
7272 ContainerT , ArrayOrContainerT , NotAnArrayContainerError ,
7373 serialize_container , deserialize_container )
@@ -355,7 +355,7 @@ def rec(keys: Tuple[Union[str, int], ...],
355355def map_reduce_array_container (
356356 reduce_func : Callable [[Iterable [Any ]], Any ],
357357 map_func : Callable [[Any ], Any ],
358- ary : ArrayOrContainerT ) -> Any :
358+ ary : ArrayOrContainerT ) -> "DeviceArray" :
359359 """Perform a map-reduce over array containers.
360360
361361 :param reduce_func: callable used to reduce over the components of *ary*
@@ -378,7 +378,7 @@ def map_reduce_array_container(
378378def multimap_reduce_array_container (
379379 reduce_func : Callable [[Iterable [Any ]], Any ],
380380 map_func : Callable [..., Any ],
381- * args : Any ) -> Any :
381+ * args : Any ) -> "DeviceArray" :
382382 r"""Perform a map-reduce over multiple array containers.
383383
384384 :param reduce_func: callable used to reduce over the components of any
@@ -401,7 +401,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
401401def rec_map_reduce_array_container (
402402 reduce_func : Callable [[Iterable [Any ]], Any ],
403403 map_func : Callable [[Any ], Any ],
404- ary : ArrayOrContainerT ) -> Any :
404+ ary : ArrayOrContainerT ) -> "DeviceArray" :
405405 """Perform a map-reduce over array containers recursively.
406406
407407 :param reduce_func: callable used to reduce over the components of *ary*
@@ -455,7 +455,7 @@ def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
455455def rec_multimap_reduce_array_container (
456456 reduce_func : Callable [[Iterable [Any ]], Any ],
457457 map_func : Callable [..., Any ],
458- * args : Any ) -> Any :
458+ * args : Any ) -> "DeviceArray" :
459459 r"""Perform a map-reduce over multiple array containers recursively.
460460
461461 :param reduce_func: callable used to reduce over the components of any
0 commit comments