Skip to content

Commit b395fa8

Browse files
committed
add leaf_class to decorators
1 parent 3974d67 commit b395fa8

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

arraycontext/container/traversal.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,22 @@ def rec_map_array_container(
269269

270270

271271
def mapped_over_array_containers(
272-
f: Callable[[Any], Any]) -> Callable[[ArrayOrContainerT], ArrayOrContainerT]:
272+
f: Optional[Callable[[Any], Any]] = None,
273+
leaf_class: Optional[type] = None) -> Union[
274+
Callable[[ArrayOrContainerT], ArrayOrContainerT],
275+
Callable[
276+
[Callable[[Any], Any]],
277+
Callable[[ArrayOrContainerT], ArrayOrContainerT]]]:
273278
"""Decorator around :func:`rec_map_array_container`."""
274-
wrapper = partial(rec_map_array_container, f)
275-
update_wrapper(wrapper, f)
276-
return wrapper
279+
def decorator(g: Callable[[Any], Any]) -> Callable[
280+
[ArrayOrContainerT], ArrayOrContainerT]:
281+
wrapper = partial(rec_map_array_container, g, leaf_class=leaf_class)
282+
update_wrapper(wrapper, g)
283+
return wrapper
284+
if f is not None:
285+
return decorator(f)
286+
else:
287+
return decorator
277288

278289

279290
def rec_multimap_array_container(
@@ -292,15 +303,23 @@ def rec_multimap_array_container(
292303

293304

294305
def multimapped_over_array_containers(
295-
f: Callable[..., Any]) -> Callable[..., Any]:
306+
f: Optional[Callable[..., Any]] = None,
307+
leaf_class: Optional[type] = None) -> Union[
308+
Callable[..., Any],
309+
Callable[[Callable[..., Any]], Callable[..., Any]]]:
296310
"""Decorator around :func:`rec_multimap_array_container`."""
297-
# can't use functools.partial, because its result is insufficiently
298-
# function-y to be used as a method definition.
299-
def wrapper(*args: Any) -> Any:
300-
return rec_multimap_array_container(f, *args)
311+
def decorator(g: Callable[..., Any]) -> Callable[..., Any]:
312+
# can't use functools.partial, because its result is insufficiently
313+
# function-y to be used as a method definition.
314+
def wrapper(*args: Any) -> Any:
315+
return rec_multimap_array_container(g, *args, leaf_class=leaf_class)
316+
update_wrapper(wrapper, g)
317+
return wrapper
318+
if f is not None:
319+
return decorator(f)
320+
else:
321+
return decorator
301322

302-
update_wrapper(wrapper, f)
303-
return wrapper
304323

305324
# }}}
306325

0 commit comments

Comments
 (0)