Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement map_over_datasets kwargs #10012

Merged
merged 7 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ v2025.02.0 (unreleased)

New Features
~~~~~~~~~~~~

- Allow kwargs in :py:meth:`DataTree.map_over_datasets` and :py:func:`map_over_datasets` (:issue:`10009`, :pull:`10012`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,7 @@ def map_over_datasets(
self,
func: Callable,
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> DataTree | tuple[DataTree, ...]:
"""
Apply a function to every dataset in this subtree, returning a new tree which stores the results.
Expand All @@ -1446,7 +1447,10 @@ def map_over_datasets(

Function will not be applied to any nodes without datasets.
*args : tuple, optional
Positional arguments passed on to `func`.
Positional arguments passed on to `func`. Any DataTree arguments will be
converted to Dataset objects via `.dataset`.
kwargs : dict, optional
Optional keyword arguments passed directly to ``func``.

Returns
-------
Expand All @@ -1459,7 +1463,7 @@ def map_over_datasets(
"""
# TODO this signature means that func has no way to know which node it is being called upon - change?
# TODO fix this typing error
return map_over_datasets(func, self, *args)
return map_over_datasets(func, self, *args, kwargs=kwargs)

def pipe(
self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any
Expand Down
30 changes: 24 additions & 6 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,38 @@


@overload
def map_over_datasets(func: Callable[..., Dataset | None], *args: Any) -> DataTree: ...
def map_over_datasets(
func: Callable[
...,
Dataset | None,
],
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> DataTree: ...


@overload
def map_over_datasets(
func: Callable[..., tuple[Dataset | None, Dataset | None]], *args: Any
func: Callable[..., tuple[Dataset | None, Dataset | None]],
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> tuple[DataTree, DataTree]: ...


# add an explicit overload for the most common case of two return values
# (python typing does not have a way to match tuple lengths in general)
@overload
def map_over_datasets(
func: Callable[..., tuple[Dataset | None, ...]], *args: Any
func: Callable[..., tuple[Dataset | None, ...]],
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> tuple[DataTree, ...]: ...


def map_over_datasets(
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], *args: Any
func: Callable[..., Dataset | None | tuple[Dataset | None, ...]],
*args: Any,
kwargs: Mapping[str, Any] | None = None,
) -> DataTree | tuple[DataTree, ...]:
"""
Applies a function to every dataset in one or more DataTree objects with
Expand Down Expand Up @@ -62,12 +75,14 @@ def map_over_datasets(
func : callable
Function to apply to datasets with signature:

`func(*args: Dataset) -> Union[Dataset, tuple[Dataset, ...]]`.
`func(*args: Dataset, **kwargs) -> Union[Dataset, tuple[Dataset, ...]]`.

(i.e. func must accept at least one Dataset and return at least one Dataset.)
*args : tuple, optional
Positional arguments passed on to `func`. Any DataTree arguments will be
converted to Dataset objects via `.dataset`.
kwargs : dict, optional
Optional keyword arguments passed directly to ``func``.

Returns
-------
Expand All @@ -85,6 +100,9 @@ def map_over_datasets(

from xarray.core.datatree import DataTree

if kwargs is None:
kwargs = {}

# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
Expand All @@ -100,7 +118,7 @@ def map_over_datasets(
node_dataset_args.insert(i, arg)

func_with_error_context = _handle_errors_with_path_context(path)(func)
results = func_with_error_context(*node_dataset_args)
results = func_with_error_context(*node_dataset_args, **kwargs)
out_data_objects[path] = results

num_return_values = _check_all_return_values(out_data_objects)
Expand Down
23 changes: 23 additions & 0 deletions xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ def test_single_tree_arg_plus_arg(self, create_test_datatree):
result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt)
assert_equal(result_tree, expected)

def test_single_tree_arg_plus_kwarg(self, create_test_datatree):
    """Check that ``kwargs`` given to ``map_over_datasets`` reach the mapped function.

    The mapped function receives its multiplier only through keyword
    arguments, so the result can match the expected tree only if the
    ``kwargs`` mapping is forwarded on every call.
    """
    dt = create_test_datatree()
    expected = create_test_datatree(modify=lambda ds: (10.0 * ds))

    def multiply_by_kwarg(ds, **kwargs):
        # The scale factor is supplied exclusively via the forwarded kwargs.
        factor = kwargs.pop("multiplier")
        return ds * factor

    forwarded = dict(multiplier=10.0)
    result_tree = map_over_datasets(multiply_by_kwarg, dt, kwargs=forwarded)
    assert_equal(result_tree, expected)

def test_multiple_tree_args(self, create_test_datatree):
dt1 = create_test_datatree()
dt2 = create_test_datatree()
Expand Down Expand Up @@ -138,6 +151,16 @@ def multiply(ds, times):
result_tree = dt.map_over_datasets(multiply, 10.0)
assert_equal(result_tree, expected)

def test_tree_method_with_kwarg(self, create_test_datatree):
    """Check that the ``map_over_datasets`` method forwards ``kwargs`` to ``func``.

    Same scenario as the function-level kwarg test, but driven through the
    method on the tree object returned by the fixture.
    """
    dt = create_test_datatree()

    def multiply(ds, **kwargs):
        # ``times`` is only available through the forwarded keyword arguments.
        factor = kwargs.pop("times")
        return factor * ds

    expected = create_test_datatree(modify=lambda ds: 10.0 * ds)
    forwarded = dict(times=10.0)
    result_tree = dt.map_over_datasets(multiply, kwargs=forwarded)
    assert_equal(result_tree, expected)

def test_discard_ancestry(self, create_test_datatree):
# Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48
dt = create_test_datatree()
Expand Down
Loading