Skip to content

Commit 22ae74c

Browse files
authored
Add support for patching xarrays when creating datasets (#160)
* Add support for patching xarrays when creating datasets * Update CHANGELOG.md
1 parent 871f262 commit 22ae74c

File tree

6 files changed

+108
-7
lines changed

6 files changed

+108
-7
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Keep it human-readable, your future self will thank you!
1616
- Fix negative variance for constant variables (#148)
1717
- Fix cutout slicing of grid dimension (#145)
1818
- update acumulation (#158)
19+
- Add ability to patch xarrays (#160)
1920

2021
### Added
2122

src/anemoi/datasets/create/functions/sources/xarray/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def check(what, ds, paths, **kwargs):
2929
raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})")
3030

3131

32-
def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs):
32+
def load_one(emoji, context, dates, dataset, *, options={}, flavour=None, patch=None, **kwargs):
3333
import xarray as xr
3434

3535
"""
@@ -54,10 +54,10 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
5454
else:
5555
data = xr.open_dataset(dataset, **options)
5656

57-
fs = XarrayFieldList.from_xarray(data, flavour)
57+
fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
5858

5959
if len(dates) == 0:
60-
return fs.sel(**kwargs)
60+
result = fs.sel(**kwargs)
6161
else:
6262
result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
6363

src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from .field import EmptyFieldList
1818
from .flavour import CoordinateGuesser
19+
from .patch import patch_dataset
1920
from .time import Time
2021
from .variable import FilteredVariable
2122
from .variable import Variable
@@ -49,7 +50,11 @@ def __getitem__(self, i):
4950
raise IndexError(k)
5051

5152
@classmethod
52-
def from_xarray(cls, ds, flavour=None):
53+
def from_xarray(cls, ds, *, flavour=None, patch=None):
54+
55+
if patch is not None:
56+
ds = patch_dataset(ds, patch)
57+
5358
variables = []
5459

5560
if isinstance(flavour, str):
@@ -83,6 +88,8 @@ def _skip_attr(v, attr_name):
8388
_skip_attr(variable, "bounds")
8489
_skip_attr(variable, "grid_mapping")
8590

91+
LOG.debug("Xarray data_vars: %s", ds.data_vars)
92+
8693
# Select only geographical variables
8794
for name in ds.data_vars:
8895

@@ -97,13 +104,15 @@ def _skip_attr(v, attr_name):
97104
c = guess.guess(ds[coord], coord)
98105
assert c, f"Could not guess coordinate for {coord}"
99106
if coord not in variable.dims:
107+
LOG.debug("%s: coord=%s (not a dimension): dims=%s", variable, coord, variable.dims)
100108
c.is_dim = False
101109
coordinates.append(c)
102110

103111
grid_coords = sum(1 for c in coordinates if c.is_grid and c.is_dim)
104112
assert grid_coords <= 2
105113

106114
if grid_coords < 2:
115+
LOG.debug("Skipping %s (not 2D): %s", variable, [(c, c.is_grid, c.is_dim) for c in coordinates])
107116
continue
108117

109118
v = Variable(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
import logging
12+
13+
LOG = logging.getLogger(__name__)
14+
15+
16+
def patch_attributes(ds, attributes):
17+
for name, value in attributes.items():
18+
variable = ds[name]
19+
variable.attrs.update(value)
20+
21+
return ds
22+
23+
24+
def patch_coordinates(ds, coordinates):
25+
for name in coordinates:
26+
ds = ds.assign_coords({name: ds[name]})
27+
28+
return ds
29+
30+
31+
PATCHES = {
32+
"attributes": patch_attributes,
33+
"coordinates": patch_coordinates,
34+
}
35+
36+
37+
def patch_dataset(ds, patch):
38+
for what, values in patch.items():
39+
if what not in PATCHES:
40+
raise ValueError(f"Unknown patch type {what!r}")
41+
42+
ds = PATCHES[what](ds, values)
43+
44+
return ds

src/anemoi/datasets/data/dataset.py

+47
Original file line numberDiff line numberDiff line change
@@ -512,3 +512,50 @@ def _compute_constant_fields_from_statistics(self):
512512
result.append(v)
513513

514514
return result
515+
516+
def plot(self, date, variable, member=0, **kwargs):
517+
"""For debugging purposes, plot a field.
518+
519+
Parameters
520+
----------
521+
date : int or datetime.datetime or numpy.datetime64 or str
522+
The date to plot.
523+
variable : int or str
524+
The variable to plot.
525+
member : int, optional
526+
The ensemble member to plot.
527+
528+
**kwargs:
529+
Additional arguments to pass to matplotlib.pyplot.tricontourf
530+
531+
532+
Returns
533+
-------
534+
matplotlib.pyplot.Axes
535+
"""
536+
537+
from anemoi.utils.devtools import plot_values
538+
from earthkit.data.utils.dates import to_datetime
539+
540+
if not isinstance(date, int):
541+
date = np.datetime64(to_datetime(date)).astype(self.dates[0].dtype)
542+
index = np.where(self.dates == date)[0]
543+
if len(index) == 0:
544+
raise ValueError(
545+
f"Date {date} not found in the dataset {self.dates[0]} to {self.dates[-1]} by {self.frequency}"
546+
)
547+
date_index = index[0]
548+
else:
549+
date_index = date
550+
551+
if isinstance(variable, int):
552+
variable_index = variable
553+
else:
554+
if variable not in self.variables:
555+
raise ValueError(f"Unknown variable {variable} (available: {self.variables})")
556+
557+
variable_index = self.name_to_index[variable]
558+
559+
values = self[date_index, variable_index, member]
560+
561+
return plot_values(values, self.latitudes, self.longitudes, **kwargs)

tests/xarray/test_zarr.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_weatherbench():
7373
"levtype": "pl",
7474
}
7575

76-
fs = XarrayFieldList.from_xarray(ds, flavour)
76+
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)
7777

7878
assert_field_list(
7979
fs,
@@ -116,7 +116,7 @@ def test_noaa_replay():
116116
"levtype": "pl",
117117
}
118118

119-
fs = XarrayFieldList.from_xarray(ds, flavour)
119+
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)
120120

121121
assert_field_list(
122122
fs,
@@ -141,7 +141,7 @@ def test_planetary_computer_conus404():
141141
},
142142
}
143143

144-
fs = XarrayFieldList.from_xarray(ds, flavour)
144+
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)
145145

146146
assert_field_list(
147147
fs,

0 commit comments

Comments
 (0)