Skip to content

Commit 1189240

Browse files
implement map_over_datasets kwargs (#10012)
* add kwargs to map_over_datasets (similar to apply_ufunc), add test. * try to fix typing * improve typing and simplify kwargs-handling per review suggestions * apply changes to DataTree.map_over_datasets * add whats-new.rst entry * Update xarray/core/datatree_mapping.py Co-authored-by: Mathias Hauser <[email protected]> * add suggestions from review. --------- Co-authored-by: Mathias Hauser <[email protected]>
1 parent c8f7dc6 commit 1189240

File tree

4 files changed

+55
-9
lines changed

4 files changed

+55
-9
lines changed

doc/whats-new.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ v2025.02.0 (unreleased)
2121

2222
New Features
2323
~~~~~~~~~~~~
24-
24+
- Allow kwargs in :py:meth:`DataTree.map_over_datasets` and :py:func:`map_over_datasets` (:issue:`10009`, :pull:`10012`).
25+
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
2526

2627
Breaking changes
2728
~~~~~~~~~~~~~~~~

xarray/core/datatree.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1429,6 +1429,7 @@ def map_over_datasets(
14291429
self,
14301430
func: Callable,
14311431
*args: Any,
1432+
kwargs: Mapping[str, Any] | None = None,
14321433
) -> DataTree | tuple[DataTree, ...]:
14331434
"""
14341435
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
@@ -1446,7 +1447,10 @@ def map_over_datasets(
14461447
14471448
Function will not be applied to any nodes without datasets.
14481449
*args : tuple, optional
1449-
Positional arguments passed on to `func`.
1450+
Positional arguments passed on to `func`. Any DataTree arguments will be
1451+
converted to Dataset objects via `.dataset`.
1452+
kwargs : dict, optional
1453+
Optional keyword arguments passed directly to ``func``.
14501454
14511455
Returns
14521456
-------
@@ -1459,7 +1463,7 @@ def map_over_datasets(
14591463
"""
14601464
# TODO this signature means that func has no way to know which node it is being called upon - change?
14611465
# TODO fix this typing error
1462-
return map_over_datasets(func, self, *args)
1466+
return map_over_datasets(func, self, *args, kwargs=kwargs)
14631467

14641468
def pipe(
14651469
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any

xarray/core/datatree_mapping.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,38 @@
1313

1414

1515
@overload
16-
def map_over_datasets(func: Callable[..., Dataset | None], *args: Any) -> DataTree: ...
16+
def map_over_datasets(
17+
func: Callable[
18+
...,
19+
Dataset | None,
20+
],
21+
*args: Any,
22+
kwargs: Mapping[str, Any] | None = None,
23+
) -> DataTree: ...
1724

1825

1926
@overload
2027
def map_over_datasets(
21-
func: Callable[..., tuple[Dataset | None, Dataset | None]], *args: Any
28+
func: Callable[..., tuple[Dataset | None, Dataset | None]],
29+
*args: Any,
30+
kwargs: Mapping[str, Any] | None = None,
2231
) -> tuple[DataTree, DataTree]: ...
2332

2433

2534
# add an expect overload for the most common case of two return values
2635
# (python typing does not have a way to match tuple lengths in general)
2736
@overload
2837
def map_over_datasets(
29-
func: Callable[..., tuple[Dataset | None, ...]], *args: Any
38+
func: Callable[..., tuple[Dataset | None, ...]],
39+
*args: Any,
40+
kwargs: Mapping[str, Any] | None = None,
3041
) -> tuple[DataTree, ...]: ...
3142

3243

3344
def map_over_datasets(
34-
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], *args: Any
45+
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
46+
*args: Any,
47+
kwargs: Mapping[str, Any] | None = None,
3548
) -> DataTree | tuple[DataTree, ...]:
3649
"""
3750
Applies a function to every dataset in one or more DataTree objects with
@@ -62,12 +75,14 @@ def map_over_datasets(
6275
func : callable
6376
Function to apply to datasets with signature:
6477
65-
`func(*args: Dataset) -> Union[Dataset, tuple[Dataset, ...]]`.
78+
`func(*args: Dataset, **kwargs) -> Union[Dataset, tuple[Dataset, ...]]`.
6679
6780
(i.e. func must accept at least one Dataset and return at least one Dataset.)
6881
*args : tuple, optional
6982
Positional arguments passed on to `func`. Any DataTree arguments will be
7083
converted to Dataset objects via `.dataset`.
84+
kwargs : dict, optional
85+
Optional keyword arguments passed directly to ``func``.
7186
7287
Returns
7388
-------
@@ -85,6 +100,9 @@ def map_over_datasets(
85100

86101
from xarray.core.datatree import DataTree
87102

103+
if kwargs is None:
104+
kwargs = {}
105+
88106
# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
89107
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
90108
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
@@ -100,7 +118,7 @@ def map_over_datasets(
100118
node_dataset_args.insert(i, arg)
101119

102120
func_with_error_context = _handle_errors_with_path_context(path)(func)
103-
results = func_with_error_context(*node_dataset_args)
121+
results = func_with_error_context(*node_dataset_args, **kwargs)
104122
out_data_objects[path] = results
105123

106124
num_return_values = _check_all_return_values(out_data_objects)

xarray/tests/test_datatree_mapping.py

+23
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ def test_single_tree_arg_plus_arg(self, create_test_datatree):
5151
result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt)
5252
assert_equal(result_tree, expected)
5353

54+
def test_single_tree_arg_plus_kwarg(self, create_test_datatree):
55+
dt = create_test_datatree()
56+
expected = create_test_datatree(modify=lambda ds: (10.0 * ds))
57+
58+
def multiply_by_kwarg(ds, **kwargs):
59+
ds = ds * kwargs.pop("multiplier")
60+
return ds
61+
62+
result_tree = map_over_datasets(
63+
multiply_by_kwarg, dt, kwargs=dict(multiplier=10.0)
64+
)
65+
assert_equal(result_tree, expected)
66+
5467
def test_multiple_tree_args(self, create_test_datatree):
5568
dt1 = create_test_datatree()
5669
dt2 = create_test_datatree()
@@ -138,6 +151,16 @@ def multiply(ds, times):
138151
result_tree = dt.map_over_datasets(multiply, 10.0)
139152
assert_equal(result_tree, expected)
140153

154+
def test_tree_method_with_kwarg(self, create_test_datatree):
155+
dt = create_test_datatree()
156+
157+
def multiply(ds, **kwargs):
158+
return kwargs.pop("times") * ds
159+
160+
expected = create_test_datatree(modify=lambda ds: 10.0 * ds)
161+
result_tree = dt.map_over_datasets(multiply, kwargs=dict(times=10.0))
162+
assert_equal(result_tree, expected)
163+
141164
def test_discard_ancestry(self, create_test_datatree):
142165
# Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48
143166
dt = create_test_datatree()

0 commit comments

Comments
 (0)