2
2
3
3
import sys
4
4
from collections .abc import Callable , Mapping
5
+ from functools import partial
5
6
from typing import TYPE_CHECKING , Any , cast , overload
6
7
7
8
from xarray .core .dataset import Dataset
@@ -31,7 +32,9 @@ def map_over_datasets(
31
32
32
33
33
34
def map_over_datasets (
34
- func : Callable [..., Dataset | None | tuple [Dataset | None , ...]], * args : Any
35
+ func : Callable [..., Dataset | None | tuple [Dataset | None , ...]],
36
+ * args : Any ,
37
+ kwargs : Mapping | None = None ,
35
38
) -> DataTree | tuple [DataTree , ...]:
36
39
"""
37
40
Applies a function to every dataset in one or more DataTree objects with
@@ -68,6 +71,8 @@ def map_over_datasets(
68
71
*args : tuple, optional
69
72
Positional arguments passed on to `func`. Any DataTree arguments will be
70
73
converted to Dataset objects via `.dataset`.
74
+ kwargs : dict, optional
75
+ Optional keyword arguments passed directly on to call ``func``.
71
76
72
77
Returns
73
78
-------
@@ -85,6 +90,12 @@ def map_over_datasets(
85
90
86
91
from xarray .core .datatree import DataTree
87
92
93
+ if kwargs is None :
94
+ kwargs = {}
95
+
96
+ if kwargs :
97
+ func = partial (func , ** kwargs )
98
+
88
99
# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
89
100
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
90
101
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
0 commit comments