Skip to content

Commit 3c07cce

Browse files
committed
add kwargs to map_over_datasets (similar to apply_ufunc), add test.
1 parent c252152 commit 3c07cce

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

xarray/core/datatree_mapping.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import sys
44
from collections.abc import Callable, Mapping
5+
from functools import partial
56
from typing import TYPE_CHECKING, Any, cast, overload
67

78
from xarray.core.dataset import Dataset
@@ -31,7 +32,9 @@ def map_over_datasets(
3132

3233

3334
def map_over_datasets(
34-
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], *args: Any
35+
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
36+
*args: Any,
37+
kwargs: Mapping | None = None,
3538
) -> DataTree | tuple[DataTree, ...]:
3639
"""
3740
Applies a function to every dataset in one or more DataTree objects with
@@ -68,6 +71,8 @@ def map_over_datasets(
6871
*args : tuple, optional
6972
Positional arguments passed on to `func`. Any DataTree arguments will be
7073
converted to Dataset objects via `.dataset`.
74+
kwargs : dict, optional
75+
Optional keyword arguments passed directly on to call ``func``.
7176
7277
Returns
7378
-------
@@ -85,6 +90,12 @@ def map_over_datasets(
8590

8691
from xarray.core.datatree import DataTree
8792

93+
if kwargs is None:
94+
kwargs = {}
95+
96+
if kwargs:
97+
func = partial(func, **kwargs)
98+
8899
# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
89100
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
90101
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return

xarray/tests/test_datatree_mapping.py

+13
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ def test_single_tree_arg_plus_arg(self, create_test_datatree):
5151
result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt)
5252
assert_equal(result_tree, expected)
5353

54+
def test_single_tree_arg_plus_kwarg(self, create_test_datatree):
55+
dt = create_test_datatree()
56+
expected = create_test_datatree(modify=lambda ds: (10.0 * ds))
57+
58+
def multiply_by_kwarg(ds, **kwargs):
59+
ds = ds * kwargs.pop("multiplier")
60+
return ds
61+
62+
result_tree = map_over_datasets(
63+
multiply_by_kwarg, dt, kwargs=dict(multiplier=10.0)
64+
)
65+
assert_equal(result_tree, expected)
66+
5467
def test_multiple_tree_args(self, create_test_datatree):
5568
dt1 = create_test_datatree()
5669
dt2 = create_test_datatree()

0 commit comments

Comments
 (0)