13
13
14
14
15
15
@overload
16
- def map_over_datasets (func : Callable [..., Dataset | None ], * args : Any ) -> DataTree : ...
16
+ def map_over_datasets (
17
+ func : Callable [
18
+ ...,
19
+ Dataset | None ,
20
+ ],
21
+ * args : Any ,
22
+ kwargs : Mapping [str , Any ] | None = None ,
23
+ ) -> DataTree : ...
17
24
18
25
19
26
@overload
20
27
def map_over_datasets (
21
- func : Callable [..., tuple [Dataset | None , Dataset | None ]], * args : Any
28
+ func : Callable [..., tuple [Dataset | None , Dataset | None ]],
29
+ * args : Any ,
30
+ kwargs : Mapping [str , Any ] | None = None ,
22
31
) -> tuple [DataTree , DataTree ]: ...
23
32
24
33
25
34
# add an expect overload for the most common case of two return values
26
35
# (python typing does not have a way to match tuple lengths in general)
27
36
@overload
28
37
def map_over_datasets (
29
- func : Callable [..., tuple [Dataset | None , ...]], * args : Any
38
+ func : Callable [..., tuple [Dataset | None , ...]],
39
+ * args : Any ,
40
+ kwargs : Mapping [str , Any ] | None = None ,
30
41
) -> tuple [DataTree , ...]: ...
31
42
32
43
33
44
def map_over_datasets (
34
- func : Callable [..., Dataset | None | tuple [Dataset | None , ...]], * args : Any
45
+ func : Callable [..., Dataset | None | tuple [Dataset | None , ...]],
46
+ * args : Any ,
47
+ kwargs : Mapping [str , Any ] | None = None ,
35
48
) -> DataTree | tuple [DataTree , ...]:
36
49
"""
37
50
Applies a function to every dataset in one or more DataTree objects with
@@ -62,12 +75,14 @@ def map_over_datasets(
62
75
func : callable
63
76
Function to apply to datasets with signature:
64
77
65
- `func(*args: Dataset) -> Union[Dataset, tuple[Dataset, ...]]`.
78
+ `func(*args: Dataset, **kwargs ) -> Union[Dataset, tuple[Dataset, ...]]`.
66
79
67
80
(i.e. func must accept at least one Dataset and return at least one Dataset.)
68
81
*args : tuple, optional
69
82
Positional arguments passed on to `func`. Any DataTree arguments will be
70
83
converted to Dataset objects via `.dataset`.
84
+ kwargs : dict, optional
85
+ Optional keyword arguments passed directly to ``func``.
71
86
72
87
Returns
73
88
-------
@@ -85,6 +100,9 @@ def map_over_datasets(
85
100
86
101
from xarray .core .datatree import DataTree
87
102
103
+ if kwargs is None :
104
+ kwargs = {}
105
+
88
106
# Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
89
107
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
90
108
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
@@ -100,7 +118,7 @@ def map_over_datasets(
100
118
node_dataset_args .insert (i , arg )
101
119
102
120
func_with_error_context = _handle_errors_with_path_context (path )(func )
103
- results = func_with_error_context (* node_dataset_args )
121
+ results = func_with_error_context (* node_dataset_args , ** kwargs )
104
122
out_data_objects [path ] = results
105
123
106
124
num_return_values = _check_all_return_values (out_data_objects )
0 commit comments