Skip to content

Commit bda0dda

Browse files
authored
Merge branch 'develop' into feature/augment
2 parents 019546f + 22ae74c commit bda0dda

File tree

11 files changed

+158
-26
lines changed

11 files changed

+158
-26
lines changed

CHANGELOG.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ Keep it human-readable, your future self will thank you!
1717
- Fix cutout slicing of grid dimension (#145)
1818
- Use cKDTree instead of KDTree
1919
- Implement 'complement' feature
20-
- Update accumulations (#158)
21-
20+
- Add ability to patch xarrays (#160)
2221

2322
### Added
2423

src/anemoi/datasets/create/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,14 @@ def check_shape(cube, dates, dates_in_data):
622622

623623
check_shape(cube, dates, dates_in_data)
624624

625-
def check_dates_in_data(lst, lst2):
626-
lst2 = [np.datetime64(_) for _ in lst2]
627-
lst = [np.datetime64(_) for _ in lst]
628-
assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)
625+
def check_dates_in_data(dates_in_data, requested_dates):
626+
requested_dates = [np.datetime64(_) for _ in requested_dates]
627+
dates_in_data = [np.datetime64(_) for _ in dates_in_data]
628+
assert dates_in_data == requested_dates, (
629+
"Dates in data are not the requested ones:",
630+
dates_in_data,
631+
requested_dates,
632+
)
629633

630634
check_dates_in_data(dates_in_data, dates)
631635

src/anemoi/datasets/create/functions/sources/xarray/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def check(what, ds, paths, **kwargs):
2929
raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})")
3030

3131

32-
def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs):
32+
def load_one(emoji, context, dates, dataset, *, options={}, flavour=None, patch=None, **kwargs):
3333
import xarray as xr
3434

3535
"""
@@ -54,10 +54,10 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
5454
else:
5555
data = xr.open_dataset(dataset, **options)
5656

57-
fs = XarrayFieldList.from_xarray(data, flavour)
57+
fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)
5858

5959
if len(dates) == 0:
60-
return fs.sel(**kwargs)
60+
result = fs.sel(**kwargs)
6161
else:
6262
result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])
6363

src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from .field import EmptyFieldList
1818
from .flavour import CoordinateGuesser
19+
from .patch import patch_dataset
1920
from .time import Time
2021
from .variable import FilteredVariable
2122
from .variable import Variable
@@ -49,7 +50,11 @@ def __getitem__(self, i):
4950
raise IndexError(k)
5051

5152
@classmethod
52-
def from_xarray(cls, ds, flavour=None):
53+
def from_xarray(cls, ds, *, flavour=None, patch=None):
54+
55+
if patch is not None:
56+
ds = patch_dataset(ds, patch)
57+
5358
variables = []
5459

5560
if isinstance(flavour, str):
@@ -83,6 +88,8 @@ def _skip_attr(v, attr_name):
8388
_skip_attr(variable, "bounds")
8489
_skip_attr(variable, "grid_mapping")
8590

91+
LOG.debug("Xarray data_vars: %s", ds.data_vars)
92+
8693
# Select only geographical variables
8794
for name in ds.data_vars:
8895

@@ -97,13 +104,15 @@ def _skip_attr(v, attr_name):
97104
c = guess.guess(ds[coord], coord)
98105
assert c, f"Could not guess coordinate for {coord}"
99106
if coord not in variable.dims:
107+
LOG.debug("%s: coord=%s (not a dimension): dims=%s", variable, coord, variable.dims)
100108
c.is_dim = False
101109
coordinates.append(c)
102110

103111
grid_coords = sum(1 for c in coordinates if c.is_grid and c.is_dim)
104112
assert grid_coords <= 2
105113

106114
if grid_coords < 2:
115+
LOG.debug("Skipping %s (not 2D): %s", variable, [(c, c.is_grid, c.is_dim) for c in coordinates])
107116
continue
108117

109118
v = Variable(

src/anemoi/datasets/create/functions/sources/xarray/metadata.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class _MDMapping:
2424
def __init__(self, variable):
2525
self.variable = variable
2626
self.time = variable.time
27+
# Aliases
2728
self.mapping = dict(param="variable")
2829
for c in variable.coordinates:
2930
for v in c.mars_names:
@@ -34,7 +35,6 @@ def _from_user(self, key):
3435
return self.mapping.get(key, key)
3536

3637
def from_user(self, kwargs):
37-
print("from_user", kwargs, self)
3838
return {self._from_user(k): v for k, v in kwargs.items()}
3939

4040
def __repr__(self):
@@ -81,22 +81,16 @@ def _base_datetime(self):
8181
def _valid_datetime(self):
8282
return self._get("valid_datetime")
8383

84-
def _get(self, key, **kwargs):
84+
def get(self, key, astype=None, **kwargs):
8585

8686
if key in self._d:
87+
if astype is not None:
88+
return astype(self._d[key])
8789
return self._d[key]
8890

89-
if key.startswith("mars."):
90-
key = key[5:]
91-
if key not in self.MARS_KEYS:
92-
if kwargs.get("raise_on_missing", False):
93-
raise KeyError(f"Invalid key '{key}' in namespace='mars'")
94-
else:
95-
return kwargs.get("default", None)
96-
9791
key = self._mapping._from_user(key)
9892

99-
return super()._get(key, **kwargs)
93+
return super().get(key, astype=astype, **kwargs)
10094

10195

10296
class XArrayFieldGeography(Geography):

src/anemoi/datasets/create/functions/sources/xarray/patch.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
import logging
12+
13+
LOG = logging.getLogger(__name__)
14+
15+
16+
def patch_attributes(ds, attributes):
17+
for name, value in attributes.items():
18+
variable = ds[name]
19+
variable.attrs.update(value)
20+
21+
return ds
22+
23+
24+
def patch_coordinates(ds, coordinates):
25+
for name in coordinates:
26+
ds = ds.assign_coords({name: ds[name]})
27+
28+
return ds
29+
30+
31+
PATCHES = {
32+
"attributes": patch_attributes,
33+
"coordinates": patch_coordinates,
34+
}
35+
36+
37+
def patch_dataset(ds, patch):
38+
for what, values in patch.items():
39+
if what not in PATCHES:
40+
raise ValueError(f"Unknown patch type {what!r}")
41+
42+
ds = PATCHES[what](ds, values)
43+
44+
return ds

src/anemoi/datasets/create/functions/sources/xarray/time.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,18 @@ def from_coordinates(cls, coordinates):
6262

6363
raise NotImplementedError(f"{len(date_coordinate)=} {len(time_coordinate)=} {len(step_coordinate)=}")
6464

65+
def select_valid_datetime(self, variable):
66+
raise NotImplementedError(f"{self.__class__.__name__}.select_valid_datetime()")
67+
6568

6669
class Constant(Time):
6770

6871
def fill_time_metadata(self, coords_values, metadata):
6972
return None
7073

74+
def select_valid_datetime(self, variable):
75+
return None
76+
7177

7278
class Analysis(Time):
7379

@@ -83,6 +89,9 @@ def fill_time_metadata(self, coords_values, metadata):
8389

8490
return valid_datetime
8591

92+
def select_valid_datetime(self, variable):
93+
return self.time_coordinate_name
94+
8695

8796
class ForecastFromValidTimeAndStep(Time):
8897

@@ -116,6 +125,9 @@ def fill_time_metadata(self, coords_values, metadata):
116125

117126
return valid_datetime
118127

128+
def select_valid_datetime(self, variable):
129+
return self.time_coordinate_name
130+
119131

120132
class ForecastFromValidTimeAndBaseTime(Time):
121133

@@ -138,6 +150,9 @@ def fill_time_metadata(self, coords_values, metadata):
138150

139151
return valid_datetime
140152

153+
def select_valid_datetime(self, variable):
154+
return self.time_coordinate_name
155+
141156

142157
class ForecastFromBaseTimeAndDate(Time):
143158

src/anemoi/datasets/create/functions/sources/xarray/variable.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,17 @@ def __init__(
3737
self.coordinates = coordinates
3838

3939
self._metadata = metadata.copy()
40-
self._metadata.update({"variable": variable.name})
40+
self._metadata.update({"variable": variable.name, "param": variable.name})
4141

4242
self.time = time
4343

4444
self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid)
4545
self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid}
4646
self.by_name = {c.variable.name: c for c in coordinates}
4747

48+
# We need that alias for the time dimension
49+
self._aliases = dict(valid_datetime="time")
50+
4851
self.length = math.prod(self.shape)
4952

5053
@property
@@ -96,15 +99,28 @@ def sel(self, missing, **kwargs):
9699

97100
k, v = kwargs.popitem()
98101

102+
user_provided_k = k
103+
104+
if k == "valid_datetime":
105+
# Ask the Time object to select the valid datetime
106+
k = self.time.select_valid_datetime(self)
107+
if k is None:
108+
return None
109+
99110
c = self.by_name.get(k)
100111

112+
# assert c is not None, f"Could not find coordinate {k} in {self.variable.name} {self.coordinates} {list(self.by_name)}"
113+
101114
if c is None:
102115
missing[k] = v
103116
return self.sel(missing, **kwargs)
104117

105118
i = c.index(v)
106119
if i is None:
107-
LOG.warning(f"Could not find {k}={v} in {c}")
120+
if k != user_provided_k:
121+
LOG.warning(f"Could not find {user_provided_k}={v} in {c} (alias of {k})")
122+
else:
123+
LOG.warning(f"Could not find {k}={v} in {c}")
108124
return None
109125

110126
coordinates = [x.reduced(i) if c is x else x for x in self.coordinates]

src/anemoi/datasets/create/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def to_datetime(*args, **kwargs):
5454

5555

5656
def make_list_int(value):
57+
# Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers.
58+
# Moved to anemoi.utils.humanize
59+
# replace with from anemoi.utils.humanize import make_list_int
60+
# when anemoi-utils is released and pyproject.toml is updated
5761
if isinstance(value, str):
5862
if "/" not in value:
5963
return [value]

src/anemoi/datasets/data/dataset.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,3 +512,50 @@ def _compute_constant_fields_from_statistics(self):
512512
result.append(v)
513513

514514
return result
515+
516+
def plot(self, date, variable, member=0, **kwargs):
517+
"""For debugging purposes, plot a field.
518+
519+
Parameters
520+
----------
521+
date : int or datetime.datetime or numpy.datetime64 or str
522+
The date to plot.
523+
variable : int or str
524+
The variable to plot.
525+
member : int, optional
526+
The ensemble member to plot.
527+
528+
**kwargs:
529+
Additional arguments to pass to matplotlib.pyplot.tricontourf
530+
531+
532+
Returns
533+
-------
534+
matplotlib.pyplot.Axes
535+
"""
536+
537+
from anemoi.utils.devtools import plot_values
538+
from earthkit.data.utils.dates import to_datetime
539+
540+
if not isinstance(date, int):
541+
date = np.datetime64(to_datetime(date)).astype(self.dates[0].dtype)
542+
index = np.where(self.dates == date)[0]
543+
if len(index) == 0:
544+
raise ValueError(
545+
f"Date {date} not found in the dataset {self.dates[0]} to {self.dates[-1]} by {self.frequency}"
546+
)
547+
date_index = index[0]
548+
else:
549+
date_index = date
550+
551+
if isinstance(variable, int):
552+
variable_index = variable
553+
else:
554+
if variable not in self.variables:
555+
raise ValueError(f"Unknown variable {variable} (available: {self.variables})")
556+
557+
variable_index = self.name_to_index[variable]
558+
559+
values = self[date_index, variable_index, member]
560+
561+
return plot_values(values, self.latitudes, self.longitudes, **kwargs)

tests/xarray/test_zarr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_weatherbench():
7373
"levtype": "pl",
7474
}
7575

76-
fs = XarrayFieldList.from_xarray(ds, flavour)
76+
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)
7777

7878
assert_field_list(
7979
fs,
@@ -116,7 +116,7 @@ def test_noaa_replay():
116116
"levtype": "pl",
117117
}
118118

119-
fs = XarrayFieldList.from_xarray(ds, flavour)
119+
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)
120120

121121
assert_field_list(
122122
fs,
@@ -141,7 +141,7 @@ def test_planetary_computer_conus404():
141141
},
142142
}
143143

144-
fs = XarrayFieldList.from_xarray(ds, flavour)
144+
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)
145145

146146
assert_field_list(
147147
fs,

0 commit comments

Comments
 (0)