From a9e1f2800b961d681801fb89740196e793858b64 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Wed, 18 Dec 2024 10:48:56 +0100 Subject: [PATCH 1/3] add comment (#161) --- src/anemoi/datasets/create/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/anemoi/datasets/create/utils.py b/src/anemoi/datasets/create/utils.py index 578e31a19..ce6df19e3 100644 --- a/src/anemoi/datasets/create/utils.py +++ b/src/anemoi/datasets/create/utils.py @@ -54,6 +54,10 @@ def to_datetime(*args, **kwargs): def make_list_int(value): + # Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers. + # Moved to anemoi.utils.humanize + # replace with from anemoi.utils.humanize import make_list_int + # when anemoi-utils is released and pyproject.toml is updated if isinstance(value, str): if "/" not in value: return [value] From 871f262af88959b95147d97b737dab13e661e606 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Wed, 18 Dec 2024 10:22:39 +0000 Subject: [PATCH 2/3] Fix for #155 and #116 (#159) * Fix for #155 and #116 --- src/anemoi/datasets/create/__init__.py | 12 +++++++---- .../functions/sources/xarray/metadata.py | 16 +++++---------- .../create/functions/sources/xarray/time.py | 15 ++++++++++++++ .../functions/sources/xarray/variable.py | 20 +++++++++++++++++-- 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index adf5b79f1..d623ade28 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -622,10 +622,14 @@ def check_shape(cube, dates, dates_in_data): check_shape(cube, dates, dates_in_data) - def check_dates_in_data(lst, lst2): - lst2 = [np.datetime64(_) for _ in lst2] - lst = [np.datetime64(_) for _ in lst] - assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2) + def check_dates_in_data(dates_in_data, requested_dates): + requested_dates = [np.datetime64(_) for _ in requested_dates] + dates_in_data = [np.datetime64(_) for _ in dates_in_data] + assert dates_in_data == requested_dates, ( + "Dates in data are not the requested ones:", + dates_in_data, + requested_dates, + ) check_dates_in_data(dates_in_data, dates) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py index 6744ace9a..ca574001d 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py @@ -24,6 +24,7 @@ class _MDMapping: def __init__(self, variable): self.variable = variable self.time = variable.time + # Aliases self.mapping = dict(param="variable") for c in variable.coordinates: for v in c.mars_names: @@ -34,7 +35,6 @@ def _from_user(self, key): return self.mapping.get(key, key) def from_user(self, kwargs): - print("from_user", kwargs, self) return {self._from_user(k): v for k, v in kwargs.items()} def __repr__(self): @@ -81,22 +81,16 @@ def _base_datetime(self): def _valid_datetime(self): return self._get("valid_datetime") - def _get(self, key, **kwargs): + def get(self, key, astype=None, **kwargs): if key in self._d: + if astype is not None: + return astype(self._d[key]) return self._d[key] - if key.startswith("mars."): - key = key[5:] - if key not in self.MARS_KEYS: - if kwargs.get("raise_on_missing", False): - raise KeyError(f"Invalid key '{key}' in namespace='mars'") - else: - return kwargs.get("default", None) - key = self._mapping._from_user(key) - return super()._get(key, **kwargs) + return super().get(key, astype=astype, **kwargs) class XArrayFieldGeography(Geography): diff --git a/src/anemoi/datasets/create/functions/sources/xarray/time.py b/src/anemoi/datasets/create/functions/sources/xarray/time.py index dcee1d3fb..533408adc 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/time.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/time.py @@ -62,12 +62,18 @@ def from_coordinates(cls, coordinates): raise NotImplementedError(f"{len(date_coordinate)=} {len(time_coordinate)=} {len(step_coordinate)=}") + def select_valid_datetime(self, variable): + raise NotImplementedError(f"{self.__class__.__name__}.select_valid_datetime()") + class Constant(Time): def fill_time_metadata(self, coords_values, metadata): return None + def select_valid_datetime(self, variable): + return None + class Analysis(Time): @@ -83,6 +89,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class ForecastFromValidTimeAndStep(Time): @@ -116,6 +125,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class ForecastFromValidTimeAndBaseTime(Time): @@ -138,6 +150,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class ForecastFromBaseTimeAndDate(Time): diff --git a/src/anemoi/datasets/create/functions/sources/xarray/variable.py b/src/anemoi/datasets/create/functions/sources/xarray/variable.py index 7765c61f5..e8086c5ec 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/variable.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/variable.py @@ -37,7 +37,7 @@ def __init__( self.coordinates = coordinates self._metadata = metadata.copy() - self._metadata.update({"variable": variable.name}) + self._metadata.update({"variable": variable.name, "param": variable.name}) self.time = time @@ -45,6 +45,9 @@ def __init__( self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid} self.by_name = {c.variable.name: c for c in coordinates} + # We need that alias for the time dimension + self._aliases = dict(valid_datetime="time") + self.length = math.prod(self.shape) @property @@ -96,15 +99,28 @@ def sel(self, missing, **kwargs): k, v = kwargs.popitem() + user_provided_k = k + + if k == "valid_datetime": + # Ask the Time object to select the valid datetime + k = self.time.select_valid_datetime(self) + if k is None: + return None + c = self.by_name.get(k) + # assert c is not None, f"Could not find coordinate {k} in {self.variable.name} {self.coordinates} {list(self.by_name)}" + if c is None: missing[k] = v return self.sel(missing, **kwargs) i = c.index(v) if i is None: - LOG.warning(f"Could not find {k}={v} in {c}") + if k != user_provided_k: + LOG.warning(f"Could not find {user_provided_k}={v} in {c} (alias of {k})") + else: + LOG.warning(f"Could not find {k}={v} in {c}") return None coordinates = [x.reduced(i) if c is x else x for x in self.coordinates] From 22ae74c63733a684df14b52b0d2c0a5519a97f1c Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Wed, 18 Dec 2024 10:23:00 +0000 Subject: [PATCH 3/3] Add support for patching xarrays when creating datasets (#160) * Add support for patching xarrays when creating datasets * Update CHANGELOG.md --- CHANGELOG.md | 1 + .../functions/sources/xarray/__init__.py | 6 +-- .../functions/sources/xarray/fieldlist.py | 11 ++++- .../create/functions/sources/xarray/patch.py | 44 +++++++++++++++++ src/anemoi/datasets/data/dataset.py | 47 +++++++++++++++++++ tests/xarray/test_zarr.py | 6 +-- 6 files changed, 108 insertions(+), 7 deletions(-) create mode 100644 src/anemoi/datasets/create/functions/sources/xarray/patch.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 443d74638..16c19d589 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ Keep it human-readable, your future self will thank you! - Fix negative variance for constant variables (#148) - Fix cutout slicing of grid dimension (#145) - update acumulation (#158) +- Add ability to patch xarrays (#160) ### Added diff --git a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py index fda14c1f0..4a3229a05 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py @@ -29,7 +29,7 @@ def check(what, ds, paths, **kwargs): raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})") -def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs): +def load_one(emoji, context, dates, dataset, *, options={}, flavour=None, patch=None, **kwargs): import xarray as xr """ @@ -54,10 +54,10 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs) else: data = xr.open_dataset(dataset, **options) - fs = XarrayFieldList.from_xarray(data, flavour) + fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch) if len(dates) == 0: - return fs.sel(**kwargs) + result = fs.sel(**kwargs) else: result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py b/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py index 716f7b6be..305d01a4d 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py @@ -16,6 +16,7 @@ from .field import EmptyFieldList from .flavour import CoordinateGuesser +from .patch import patch_dataset from .time import Time from .variable import FilteredVariable from .variable import Variable @@ -49,7 +50,11 @@ def __getitem__(self, i): raise IndexError(k) @classmethod - def from_xarray(cls, ds, flavour=None): + def from_xarray(cls, ds, *, flavour=None, patch=None): + + if patch is not None: + ds = patch_dataset(ds, patch) + variables = [] if isinstance(flavour, str): @@ -83,6 +88,8 @@ def _skip_attr(v, attr_name): _skip_attr(variable, "bounds") _skip_attr(variable, "grid_mapping") + LOG.debug("Xarray data_vars: %s", ds.data_vars) + # Select only geographical variables for name in ds.data_vars: @@ -97,6 +104,7 @@ def _skip_attr(v, attr_name): c = guess.guess(ds[coord], coord) assert c, f"Could not guess coordinate for {coord}" if coord not in variable.dims: + LOG.debug("%s: coord=%s (not a dimension): dims=%s", variable, coord, variable.dims) c.is_dim = False coordinates.append(c) @@ -104,6 +112,7 @@ def _skip_attr(v, attr_name): assert grid_coords <= 2 if grid_coords < 2: + LOG.debug("Skipping %s (not 2D): %s", variable, [(c, c.is_grid, c.is_dim) for c in coordinates]) continue v = Variable( diff --git a/src/anemoi/datasets/create/functions/sources/xarray/patch.py b/src/anemoi/datasets/create/functions/sources/xarray/patch.py new file mode 100644 index 000000000..dbe2b59c7 --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/xarray/patch.py @@ -0,0 +1,44 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging + +LOG = logging.getLogger(__name__) + + +def patch_attributes(ds, attributes): + for name, value in attributes.items(): + variable = ds[name] + variable.attrs.update(value) + + return ds + + +def patch_coordinates(ds, coordinates): + for name in coordinates: + ds = ds.assign_coords({name: ds[name]}) + + return ds + + +PATCHES = { + "attributes": patch_attributes, + "coordinates": patch_coordinates, +} + + +def patch_dataset(ds, patch): + for what, values in patch.items(): + if what not in PATCHES: + raise ValueError(f"Unknown patch type {what!r}") + + ds = PATCHES[what](ds, values) + + return ds diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index a9bb6e4bb..fcb7a3842 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -512,3 +512,50 @@ def _compute_constant_fields_from_statistics(self): result.append(v) return result + + def plot(self, date, variable, member=0, **kwargs): + """For debugging purposes, plot a field. + + Parameters + ---------- + date : int or datetime.datetime or numpy.datetime64 or str + The date to plot. + variable : int or str + The variable to plot. + member : int, optional + The ensemble member to plot. + + **kwargs: + Additional arguments to pass to matplotlib.pyplot.tricontourf + + + Returns + ------- + matplotlib.pyplot.Axes + """ + + from anemoi.utils.devtools import plot_values + from earthkit.data.utils.dates import to_datetime + + if not isinstance(date, int): + date = np.datetime64(to_datetime(date)).astype(self.dates[0].dtype) + index = np.where(self.dates == date)[0] + if len(index) == 0: + raise ValueError( + f"Date {date} not found in the dataset {self.dates[0]} to {self.dates[-1]} by {self.frequency}" + ) + date_index = index[0] + else: + date_index = date + + if isinstance(variable, int): + variable_index = variable + else: + if variable not in self.variables: + raise ValueError(f"Unknown variable {variable} (available: {self.variables})") + + variable_index = self.name_to_index[variable] + + values = self[date_index, variable_index, member] + + return plot_values(values, self.latitudes, self.longitudes, **kwargs) diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 5847063c2..e39885424 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -73,7 +73,7 @@ def test_weatherbench(): "levtype": "pl", } - fs = XarrayFieldList.from_xarray(ds, flavour) + fs = XarrayFieldList.from_xarray(ds, flavour=flavour) assert_field_list( fs, @@ -116,7 +116,7 @@ def test_noaa_replay(): "levtype": "pl", } - fs = XarrayFieldList.from_xarray(ds, flavour) + fs = XarrayFieldList.from_xarray(ds, flavour=flavour) assert_field_list( fs, @@ -141,7 +141,7 @@ def test_planetary_computer_conus404(): }, } - fs = XarrayFieldList.from_xarray(ds, flavour) + fs = XarrayFieldList.from_xarray(ds, flavour=flavour) assert_field_list( fs,