From 871f262af88959b95147d97b737dab13e661e606 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Wed, 18 Dec 2024 10:22:39 +0000 Subject: [PATCH] Fix for #155 and #116 (#159) * Fix for #155 and #116 --- src/anemoi/datasets/create/__init__.py | 12 +++++++---- .../functions/sources/xarray/metadata.py | 16 +++++---------- .../create/functions/sources/xarray/time.py | 15 ++++++++++++++ .../functions/sources/xarray/variable.py | 20 +++++++++++++++++-- 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index adf5b79f..d623ade2 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -622,10 +622,14 @@ def check_shape(cube, dates, dates_in_data): check_shape(cube, dates, dates_in_data) - def check_dates_in_data(lst, lst2): - lst2 = [np.datetime64(_) for _ in lst2] - lst = [np.datetime64(_) for _ in lst] - assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2) + def check_dates_in_data(dates_in_data, requested_dates): + requested_dates = [np.datetime64(_) for _ in requested_dates] + dates_in_data = [np.datetime64(_) for _ in dates_in_data] + assert dates_in_data == requested_dates, ( + "Dates in data are not the requested ones:", + dates_in_data, + requested_dates, + ) check_dates_in_data(dates_in_data, dates) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py index 6744ace9..ca574001 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py @@ -24,6 +24,7 @@ class _MDMapping: def __init__(self, variable): self.variable = variable self.time = variable.time + # Aliases self.mapping = dict(param="variable") for c in variable.coordinates: for v in c.mars_names: @@ -34,7 +35,6 @@ def _from_user(self, key): return self.mapping.get(key, key) def from_user(self, kwargs): - print("from_user", kwargs, self) return {self._from_user(k): v for k, v in kwargs.items()} def __repr__(self): @@ -81,22 +81,16 @@ def _base_datetime(self): def _valid_datetime(self): return self._get("valid_datetime") - def _get(self, key, **kwargs): + def get(self, key, astype=None, **kwargs): if key in self._d: + if astype is not None: + return astype(self._d[key]) return self._d[key] - if key.startswith("mars."): - key = key[5:] - if key not in self.MARS_KEYS: - if kwargs.get("raise_on_missing", False): - raise KeyError(f"Invalid key '{key}' in namespace='mars'") - else: - return kwargs.get("default", None) - key = self._mapping._from_user(key) - return super()._get(key, **kwargs) + return super().get(key, astype=astype, **kwargs) class XArrayFieldGeography(Geography): diff --git a/src/anemoi/datasets/create/functions/sources/xarray/time.py b/src/anemoi/datasets/create/functions/sources/xarray/time.py index dcee1d3f..533408ad 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/time.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/time.py @@ -62,12 +62,18 @@ def from_coordinates(cls, coordinates): raise NotImplementedError(f"{len(date_coordinate)=} {len(time_coordinate)=} {len(step_coordinate)=}") + def select_valid_datetime(self, variable): + raise NotImplementedError(f"{self.__class__.__name__}.select_valid_datetime()") + class Constant(Time): def fill_time_metadata(self, coords_values, metadata): return None + def select_valid_datetime(self, variable): + return None + class Analysis(Time): @@ -83,6 +89,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class ForecastFromValidTimeAndStep(Time): @@ -116,6 +125,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class ForecastFromValidTimeAndBaseTime(Time): @@ -138,6 +150,9 @@ def fill_time_metadata(self, coords_values, metadata): return valid_datetime + def select_valid_datetime(self, variable): + return self.time_coordinate_name + class ForecastFromBaseTimeAndDate(Time): diff --git a/src/anemoi/datasets/create/functions/sources/xarray/variable.py b/src/anemoi/datasets/create/functions/sources/xarray/variable.py index 7765c61f..e8086c5e 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/variable.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/variable.py @@ -37,7 +37,7 @@ def __init__( self.coordinates = coordinates self._metadata = metadata.copy() - self._metadata.update({"variable": variable.name}) + self._metadata.update({"variable": variable.name, "param": variable.name}) self.time = time @@ -45,6 +45,9 @@ def __init__( self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid} self.by_name = {c.variable.name: c for c in coordinates} + # We need that alias for the time dimension + self._aliases = dict(valid_datetime="time") + self.length = math.prod(self.shape) @property @@ -96,15 +99,28 @@ def sel(self, missing, **kwargs): k, v = kwargs.popitem() + user_provided_k = k + + if k == "valid_datetime": + # Ask the Time object to select the valid datetime + k = self.time.select_valid_datetime(self) + if k is None: + return None + c = self.by_name.get(k) + # assert c is not None, f"Could not find coordinate {k} in {self.variable.name} {self.coordinates} {list(self.by_name)}" + if c is None: missing[k] = v return self.sel(missing, **kwargs) i = c.index(v) if i is None: - LOG.warning(f"Could not find {k}={v} in {c}") + if k != user_provided_k: + LOG.warning(f"Could not find {user_provided_k}={v} in {c} (alias of {k})") + else: + LOG.warning(f"Could not find {k}={v} in {c}") return None coordinates = [x.reduced(i) if c is x else x for x in self.coordinates]