Skip to content

Commit 38ce838

Browse files
committed
allow specifying leaf class in recursive map and map-reduce
1 parent 957be2f commit 38ce838

File tree

1 file changed

+30
-17
lines changed

1 file changed

+30
-17
lines changed

arraycontext/container/traversal.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -256,43 +256,51 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
256256

257257
def rec_map_array_container(
258258
f: Callable[[Any], Any],
259-
ary: ArrayOrContainerT) -> ArrayOrContainerT:
259+
ary: ArrayOrContainerT,
260+
leaf_class: Optional[type] = None) -> ArrayOrContainerT:
260261
r"""Applies *f* recursively to an :class:`ArrayContainer`.
261262
262263
For a non-recursive version see :func:`map_array_container`.
263264
264265
:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
265266
or an instance of a base array type.
266267
"""
267-
return _map_array_container_impl(f, ary, recursive=True)
268+
return _map_array_container_impl(f, ary, leaf_cls=leaf_class, recursive=True)
268269

269270

270271
def mapped_over_array_containers(
271-
f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]:
272+
f: Callable[[Any], Any],
273+
leaf_class: Optional[type] = None) -> Callable[
274+
[ArrayOrContainerT], ArrayOrContainerT]:
272275
"""Decorator around :func:`rec_map_array_container`."""
273-
wrapper = partial(rec_map_array_container, f)
276+
wrapper = partial(rec_map_array_container, f, leaf_class=leaf_class)
274277
update_wrapper(wrapper, f)
275278
return wrapper
276279

277280

278-
def rec_multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
281+
def rec_multimap_array_container(
282+
f: Callable[..., Any],
283+
*args: Any,
284+
leaf_class: Optional[type] = None) -> Any:
279285
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
280286
281287
For a non-recursive version see :func:`multimap_array_container`.
282288
283289
:param args: all :class:`ArrayContainer` arguments must be of the same
284290
type and with the same structure (same number of components, etc.).
285291
"""
286-
return _multimap_array_container_impl(f, *args, recursive=True)
292+
return _multimap_array_container_impl(
293+
f, *args, leaf_cls=leaf_class, recursive=True)
287294

288295

289296
def multimapped_over_array_containers(
290-
f: Callable[..., Any]) -> Callable[..., Any]:
297+
f: Callable[..., Any],
298+
leaf_class: Optional[type] = None) -> Callable[..., Any]:
291299
"""Decorator around :func:`rec_multimap_array_container`."""
292300
# can't use functools.partial, because its result is insufficiently
293301
# function-y to be used as a method definition.
294302
def wrapper(*args: Any) -> Any:
295-
return rec_multimap_array_container(f, *args)
303+
return rec_multimap_array_container(f, *args, leaf_class=leaf_class)
296304

297305
update_wrapper(wrapper, f)
298306
return wrapper
@@ -401,7 +409,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
401409
def rec_map_reduce_array_container(
402410
reduce_func: Callable[[Iterable[Any]], Any],
403411
map_func: Callable[[Any], Any],
404-
ary: ArrayOrContainerT) -> "DeviceArray":
412+
ary: ArrayOrContainerT,
413+
leaf_class: Optional[type] = None) -> "DeviceArray":
405414
"""Perform a map-reduce over array containers recursively.
406415
407416
:param reduce_func: callable used to reduce over the components of *ary*
@@ -440,22 +449,26 @@ def rec_map_reduce_array_container(
440449
or any other such traversal.
441450
"""
442451
def rec(_ary: ArrayOrContainerT) -> ArrayOrContainerT:
443-
try:
444-
iterable = serialize_container(_ary)
445-
except NotAnArrayContainerError:
452+
if type(_ary) is leaf_class:
446453
return map_func(_ary)
447454
else:
448-
return reduce_func([
449-
rec(subary) for _, subary in iterable
450-
])
455+
try:
456+
iterable = serialize_container(_ary)
457+
except NotAnArrayContainerError:
458+
return map_func(_ary)
459+
else:
460+
return reduce_func([
461+
rec(subary) for _, subary in iterable
462+
])
451463

452464
return rec(ary)
453465

454466

455467
def rec_multimap_reduce_array_container(
456468
reduce_func: Callable[[Iterable[Any]], Any],
457469
map_func: Callable[..., Any],
458-
*args: Any) -> "DeviceArray":
470+
*args: Any,
471+
leaf_class: Optional[type] = None) -> "DeviceArray":
459472
r"""Perform a map-reduce over multiple array containers recursively.
460473
461474
:param reduce_func: callable used to reduce over the components of any
@@ -478,7 +491,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
478491

479492
return _multimap_array_container_impl(
480493
map_func, *args,
481-
reduce_func=_reduce_wrapper, leaf_cls=None, recursive=True)
494+
reduce_func=_reduce_wrapper, leaf_cls=leaf_class, recursive=True)
482495

483496
# }}}
484497

0 commit comments

Comments
 (0)