@@ -256,43 +256,51 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
256
256
257
257
def rec_map_array_container (
258
258
f : Callable [[Any ], Any ],
259
- ary : ArrayOrContainerT ) -> ArrayOrContainerT :
259
+ ary : ArrayOrContainerT ,
260
+ leaf_class : Optional [type ] = None ) -> ArrayOrContainerT :
260
261
r"""Applies *f* recursively to an :class:`ArrayContainer`.
261
262
262
263
For a non-recursive version see :func:`map_array_container`.
263
264
264
265
:param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
265
266
or an instance of a base array type.
266
267
"""
267
- return _map_array_container_impl (f , ary , recursive = True )
268
+ return _map_array_container_impl (f , ary , leaf_cls = leaf_class , recursive = True )
268
269
269
270
270
271
def mapped_over_array_containers (
271
- f : Callable [[Any ], Any ]) -> Callable [[ArrayOrContainerT ], ArrayOrContainerT ]:
272
+ f : Callable [[Any ], Any ],
273
+ leaf_class : Optional [type ] = None ) -> Callable [
274
+ [ArrayOrContainerT ], ArrayOrContainerT ]:
272
275
"""Decorator around :func:`rec_map_array_container`."""
273
- wrapper = partial (rec_map_array_container , f )
276
+ wrapper = partial (rec_map_array_container , f , leaf_class = leaf_class )
274
277
update_wrapper (wrapper , f )
275
278
return wrapper
276
279
277
280
278
- def rec_multimap_array_container (f : Callable [..., Any ], * args : Any ) -> Any :
281
+ def rec_multimap_array_container (
282
+ f : Callable [..., Any ],
283
+ * args : Any ,
284
+ leaf_class : Optional [type ] = None ) -> Any :
279
285
r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
280
286
281
287
For a non-recursive version see :func:`multimap_array_container`.
282
288
283
289
:param args: all :class:`ArrayContainer` arguments must be of the same
284
290
type and with the same structure (same number of components, etc.).
285
291
"""
286
- return _multimap_array_container_impl (f , * args , recursive = True )
292
+ return _multimap_array_container_impl (
293
+ f , * args , leaf_cls = leaf_class , recursive = True )
287
294
288
295
289
296
def multimapped_over_array_containers (
290
- f : Callable [..., Any ]) -> Callable [..., Any ]:
297
+ f : Callable [..., Any ],
298
+ leaf_class : Optional [type ] = None ) -> Callable [..., Any ]:
291
299
"""Decorator around :func:`rec_multimap_array_container`."""
292
300
# can't use functools.partial, because its result is insufficiently
293
301
# function-y to be used as a method definition.
294
302
def wrapper (* args : Any ) -> Any :
295
- return rec_multimap_array_container (f , * args )
303
+ return rec_multimap_array_container (f , * args , leaf_class = leaf_class )
296
304
297
305
update_wrapper (wrapper , f )
298
306
return wrapper
@@ -401,7 +409,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
401
409
def rec_map_reduce_array_container (
402
410
reduce_func : Callable [[Iterable [Any ]], Any ],
403
411
map_func : Callable [[Any ], Any ],
404
- ary : ArrayOrContainerT ) -> "DeviceArray" :
412
+ ary : ArrayOrContainerT ,
413
+ leaf_class : Optional [type ] = None ) -> "DeviceArray" :
405
414
"""Perform a map-reduce over array containers recursively.
406
415
407
416
:param reduce_func: callable used to reduce over the components of *ary*
@@ -440,22 +449,26 @@ def rec_map_reduce_array_container(
440
449
or any other such traversal.
441
450
"""
442
451
def rec (_ary : ArrayOrContainerT ) -> ArrayOrContainerT :
443
- try :
444
- iterable = serialize_container (_ary )
445
- except NotAnArrayContainerError :
452
+ if type (_ary ) is leaf_class :
446
453
return map_func (_ary )
447
454
else :
448
- return reduce_func ([
449
- rec (subary ) for _ , subary in iterable
450
- ])
455
+ try :
456
+ iterable = serialize_container (_ary )
457
+ except NotAnArrayContainerError :
458
+ return map_func (_ary )
459
+ else :
460
+ return reduce_func ([
461
+ rec (subary ) for _ , subary in iterable
462
+ ])
451
463
452
464
return rec (ary )
453
465
454
466
455
467
def rec_multimap_reduce_array_container (
456
468
reduce_func : Callable [[Iterable [Any ]], Any ],
457
469
map_func : Callable [..., Any ],
458
- * args : Any ) -> "DeviceArray" :
470
+ * args : Any ,
471
+ leaf_class : Optional [type ] = None ) -> "DeviceArray" :
459
472
r"""Perform a map-reduce over multiple array containers recursively.
460
473
461
474
:param reduce_func: callable used to reduce over the components of any
@@ -478,7 +491,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
478
491
479
492
return _multimap_array_container_impl (
480
493
map_func , * args ,
481
- reduce_func = _reduce_wrapper , leaf_cls = None , recursive = True )
494
+ reduce_func = _reduce_wrapper , leaf_cls = leaf_class , recursive = True )
482
495
483
496
# }}}
484
497
0 commit comments