diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 56d9a3d9bed..0ffc6cb527c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,7 +21,8 @@ v2025.02.0 (unreleased) New Features ~~~~~~~~~~~~ - +- Allow kwargs in :py:meth:`DataTree.map_over_datasets` and :py:func:`map_over_datasets` (:issue:`10009`, :pull:`10012`). + By `Kai Mühlbauer `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index ee90cf7477c..00b266e9592 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1429,6 +1429,7 @@ def map_over_datasets( self, func: Callable, *args: Any, + kwargs: Mapping[str, Any] | None = None, ) -> DataTree | tuple[DataTree, ...]: """ Apply a function to every dataset in this subtree, returning a new tree which stores the results. @@ -1446,7 +1447,10 @@ def map_over_datasets( Function will not be applied to any nodes without datasets. *args : tuple, optional - Positional arguments passed on to `func`. + Positional arguments passed on to `func`. Any DataTree arguments will be + converted to Dataset objects via `.dataset`. + kwargs : dict, optional + Optional keyword arguments passed directly to ``func``. Returns ------- @@ -1459,7 +1463,7 @@ def map_over_datasets( """ # TODO this signature means that func has no way to know which node it is being called upon - change? # TODO fix this typing error - return map_over_datasets(func, self, *args) + return map_over_datasets(func, self, *args, kwargs=kwargs) def pipe( self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index fba88ea6ad2..9d78a08c11b 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -13,12 +13,21 @@ @overload -def map_over_datasets(func: Callable[..., Dataset | None], *args: Any) -> DataTree: ... +def map_over_datasets( + func: Callable[ + ..., + Dataset | None, + ], + *args: Any, + kwargs: Mapping[str, Any] | None = None, +) -> DataTree: ... @overload def map_over_datasets( - func: Callable[..., tuple[Dataset | None, Dataset | None]], *args: Any + func: Callable[..., tuple[Dataset | None, Dataset | None]], + *args: Any, + kwargs: Mapping[str, Any] | None = None, ) -> tuple[DataTree, DataTree]: ... @@ -26,12 +35,16 @@ def map_over_datasets( # (python typing does not have a way to match tuple lengths in general) @overload def map_over_datasets( - func: Callable[..., tuple[Dataset | None, ...]], *args: Any + func: Callable[..., tuple[Dataset | None, ...]], + *args: Any, + kwargs: Mapping[str, Any] | None = None, ) -> tuple[DataTree, ...]: ... def map_over_datasets( - func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], *args: Any + func: Callable[..., Dataset | None | tuple[Dataset | None, ...]], + *args: Any, + kwargs: Mapping[str, Any] | None = None, ) -> DataTree | tuple[DataTree, ...]: """ Applies a function to every dataset in one or more DataTree objects with @@ -62,12 +75,14 @@ def map_over_datasets( func : callable Function to apply to datasets with signature: - `func(*args: Dataset) -> Union[Dataset, tuple[Dataset, ...]]`. + `func(*args: Dataset, **kwargs) -> Union[Dataset, tuple[Dataset, ...]]`. (i.e. func must accept at least one Dataset and return at least one Dataset.) *args : tuple, optional Positional arguments passed on to `func`. Any DataTree arguments will be converted to Dataset objects via `.dataset`. + kwargs : dict, optional + Optional keyword arguments passed directly to ``func``. Returns ------- @@ -85,6 +100,9 @@ def map_over_datasets( from xarray.core.datatree import DataTree + if kwargs is None: + kwargs = {} + # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees # We don't know which arguments are DataTrees so we zip all arguments together as iterables # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return @@ -100,7 +118,7 @@ def map_over_datasets( node_dataset_args.insert(i, arg) func_with_error_context = _handle_errors_with_path_context(path)(func) - results = func_with_error_context(*node_dataset_args) + results = func_with_error_context(*node_dataset_args, **kwargs) out_data_objects[path] = results num_return_values = _check_all_return_values(out_data_objects) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index ec91a3c03e6..d77b6e11263 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -51,6 +51,19 @@ def test_single_tree_arg_plus_arg(self, create_test_datatree): result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt) assert_equal(result_tree, expected) + def test_single_tree_arg_plus_kwarg(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(modify=lambda ds: (10.0 * ds)) + + def multiply_by_kwarg(ds, **kwargs): + ds = ds * kwargs.pop("multiplier") + return ds + + result_tree = map_over_datasets( + multiply_by_kwarg, dt, kwargs=dict(multiplier=10.0) + ) + assert_equal(result_tree, expected) + def test_multiple_tree_args(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() @@ -138,6 +151,16 @@ def multiply(ds, times): result_tree = dt.map_over_datasets(multiply, 10.0) assert_equal(result_tree, expected) + def test_tree_method_with_kwarg(self, create_test_datatree): + dt = create_test_datatree() + + def multiply(ds, **kwargs): + return kwargs.pop("times") * ds + + expected = create_test_datatree(modify=lambda ds: 10.0 * ds) + result_tree = dt.map_over_datasets(multiply, kwargs=dict(times=10.0)) + assert_equal(result_tree, expected) + def test_discard_ancestry(self, create_test_datatree): # Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48 dt = create_test_datatree()