Skip to content

Commit

Permalink
Fix for #155 and #116 (#159)
Browse files Browse the repository at this point in the history
* Fix for #155 and #116
  • Loading branch information
b8raoult authored Dec 18, 2024
1 parent a9e1f28 commit 871f262
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 17 deletions.
12 changes: 8 additions & 4 deletions src/anemoi/datasets/create/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,14 @@ def check_shape(cube, dates, dates_in_data):

check_shape(cube, dates, dates_in_data)

def check_dates_in_data(lst, lst2):
lst2 = [np.datetime64(_) for _ in lst2]
lst = [np.datetime64(_) for _ in lst]
assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)
def check_dates_in_data(dates_in_data, requested_dates):
requested_dates = [np.datetime64(_) for _ in requested_dates]
dates_in_data = [np.datetime64(_) for _ in dates_in_data]
assert dates_in_data == requested_dates, (
"Dates in data are not the requested ones:",
dates_in_data,
requested_dates,
)

check_dates_in_data(dates_in_data, dates)

Expand Down
16 changes: 5 additions & 11 deletions src/anemoi/datasets/create/functions/sources/xarray/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class _MDMapping:
def __init__(self, variable):
self.variable = variable
self.time = variable.time
# Aliases
self.mapping = dict(param="variable")
for c in variable.coordinates:
for v in c.mars_names:
Expand All @@ -34,7 +35,6 @@ def _from_user(self, key):
return self.mapping.get(key, key)

def from_user(self, kwargs):
print("from_user", kwargs, self)
return {self._from_user(k): v for k, v in kwargs.items()}

def __repr__(self):
Expand Down Expand Up @@ -81,22 +81,16 @@ def _base_datetime(self):
def _valid_datetime(self):
return self._get("valid_datetime")

def _get(self, key, **kwargs):
def get(self, key, astype=None, **kwargs):

if key in self._d:
if astype is not None:
return astype(self._d[key])
return self._d[key]

if key.startswith("mars."):
key = key[5:]
if key not in self.MARS_KEYS:
if kwargs.get("raise_on_missing", False):
raise KeyError(f"Invalid key '{key}' in namespace='mars'")
else:
return kwargs.get("default", None)

key = self._mapping._from_user(key)

return super()._get(key, **kwargs)
return super().get(key, astype=astype, **kwargs)


class XArrayFieldGeography(Geography):
Expand Down
15 changes: 15 additions & 0 deletions src/anemoi/datasets/create/functions/sources/xarray/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,18 @@ def from_coordinates(cls, coordinates):

raise NotImplementedError(f"{len(date_coordinate)=} {len(time_coordinate)=} {len(step_coordinate)=}")

def select_valid_datetime(self, variable):
raise NotImplementedError(f"{self.__class__.__name__}.select_valid_datetime()")


class Constant(Time):

def fill_time_metadata(self, coords_values, metadata):
return None

def select_valid_datetime(self, variable):
return None


class Analysis(Time):

Expand All @@ -83,6 +89,9 @@ def fill_time_metadata(self, coords_values, metadata):

return valid_datetime

def select_valid_datetime(self, variable):
return self.time_coordinate_name


class ForecastFromValidTimeAndStep(Time):

Expand Down Expand Up @@ -116,6 +125,9 @@ def fill_time_metadata(self, coords_values, metadata):

return valid_datetime

def select_valid_datetime(self, variable):
return self.time_coordinate_name


class ForecastFromValidTimeAndBaseTime(Time):

Expand All @@ -138,6 +150,9 @@ def fill_time_metadata(self, coords_values, metadata):

return valid_datetime

def select_valid_datetime(self, variable):
return self.time_coordinate_name


class ForecastFromBaseTimeAndDate(Time):

Expand Down
20 changes: 18 additions & 2 deletions src/anemoi/datasets/create/functions/sources/xarray/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ def __init__(
self.coordinates = coordinates

self._metadata = metadata.copy()
self._metadata.update({"variable": variable.name})
self._metadata.update({"variable": variable.name, "param": variable.name})

self.time = time

self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid)
self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid}
self.by_name = {c.variable.name: c for c in coordinates}

# We need that alias for the time dimension
self._aliases = dict(valid_datetime="time")

self.length = math.prod(self.shape)

@property
Expand Down Expand Up @@ -96,15 +99,28 @@ def sel(self, missing, **kwargs):

k, v = kwargs.popitem()

user_provided_k = k

if k == "valid_datetime":
# Ask the Time object to select the valid datetime
k = self.time.select_valid_datetime(self)
if k is None:
return None

c = self.by_name.get(k)

# assert c is not None, f"Could not find coordinate {k} in {self.variable.name} {self.coordinates} {list(self.by_name)}"

if c is None:
missing[k] = v
return self.sel(missing, **kwargs)

i = c.index(v)
if i is None:
LOG.warning(f"Could not find {k}={v} in {c}")
if k != user_provided_k:
LOG.warning(f"Could not find {user_provided_k}={v} in {c} (alias of {k})")
else:
LOG.warning(f"Could not find {k}={v} in {c}")
return None

coordinates = [x.reduced(i) if c is x else x for x in self.coordinates]
Expand Down

0 comments on commit 871f262

Please sign in to comment.