Skip to content

Commit

Permalink
Feature/sub hourly (#23)
Browse files Browse the repository at this point in the history
* Move dates handling code to anemoi-utils

Support sub-hourly datasets

---------

Co-authored-by: Florian Pinault <[email protected]>
  • Loading branch information
b8raoult and floriankrb authored Aug 28, 2024
1 parent 63b3c0b commit ec22704
Show file tree
Hide file tree
Showing 25 changed files with 292 additions and 227 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Keep it human-readable, your future self will thank you!
- adds the reusable cd pypi workflow

### Changed

- Support sub-hourly datasets.
- Change negative variance detection to make it less restrictive
- Fix cutout bug that left some global grid points in the lam part

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ dynamic = [
"version",
]
dependencies = [
"anemoi-utils[provenance]>=0.3.13",
"anemoi-utils[provenance]>=0.3.15",
"numpy",
"pyyaml",
"semantic-version",
Expand Down
14 changes: 12 additions & 2 deletions src/anemoi/datasets/create/functions/sources/xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,19 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])

if len(result) == 0:
LOG.warning(f"No data found for {dataset} and dates {dates}")
LOG.warning(f"No data found for {dataset} and dates {dates} and {kwargs}")
LOG.warning(f"Options: {options}")
LOG.warning(data)

for i, k in enumerate(fs):
a = ["valid_datetime", k.metadata("valid_datetime", default=None)]
for n in kwargs.keys():
a.extend([n, k.metadata(n, default=None)])
print([str(x) for x in a])

if i > 16:
break

# LOG.warning(data)

return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Coordinate:
is_time = False
is_step = False
is_date = False
is_member = False

def __init__(self, variable):
self.variable = variable
Expand Down Expand Up @@ -201,8 +202,14 @@ def normalise(self, value):


class EnsembleCoordinate(Coordinate):
is_member = True
mars_names = ("number",)

def normalise(self, value):
    """Return ``value`` as a plain ``int`` when it is integral, else unchanged.

    For example ``3.0`` becomes ``3`` while ``2.5`` is returned as is.
    """
    as_int = int(value)
    # Only collapse to int when no information is lost by the conversion.
    return as_int if as_int == value else value


class LongitudeCoordinate(Coordinate):
is_grid = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def to_numpy(self, flatten=False, dtype=None):
return values.reshape(self.shape)

def _make_metadata(self):
return XArrayMetadata(self, self.owner.mapping)
return XArrayMetadata(self)

def grid_points(self):
return self.owner.grid_points()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def sel(self, **kwargs):

for v in self.variables:

v.update_metadata_mapping(kwargs)

# First, select matching variables
# This will consume 'param' or 'variable' from kwargs
# and return the rest
Expand Down
22 changes: 21 additions & 1 deletion src/anemoi/datasets/create/functions/sources/xarray/flavour.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


from .coordinates import DateCoordinate
from .coordinates import EnsembleCoordinate
from .coordinates import LatitudeCoordinate
from .coordinates import LevelCoordinate
from .coordinates import LongitudeCoordinate
Expand Down Expand Up @@ -135,6 +136,17 @@ def _guess(self, c, coord):
if d is not None:
return d

d = self._is_number(
c,
axis=axis,
name=name,
long_name=long_name,
standard_name=standard_name,
units=units,
)
if d is not None:
return d

if c.shape in ((1,), tuple()):
return ScalarCoordinate(c)

Expand Down Expand Up @@ -249,9 +261,13 @@ def _is_level(self, c, *, axis, name, long_name, standard_name, units):
if standard_name == "depth":
return LevelCoordinate(c, "depth")

if name == "pressure":
if name == "vertical" and units == "hPa":
return LevelCoordinate(c, "pl")

def _is_number(self, c, *, axis, name, long_name, standard_name, units):
if name in ("realization", "number"):
return EnsembleCoordinate(c)


class FlavourCoordinateGuesser(CoordinateGuesser):
def __init__(self, ds, flavour):
Expand Down Expand Up @@ -328,3 +344,7 @@ def _levtype(self, c, *, axis, name, long_name, standard_name, units):
return self.flavour["levtype"]

raise NotImplementedError(f"levtype for {c=}")

def _is_number(self, c, *, axis, name, long_name, standard_name, units):
    """Return an ensemble coordinate if the flavour identifies ``c`` as "number".

    The coordinate attributes are forwarded to ``self._match`` through
    ``locals()``, so no extra local variables may be introduced before that
    call. Returns ``None`` when ``self._match`` reports no match.
    """
    if self._match(c, "number", locals()):
        # Was DateCoordinate: a member/"number" coordinate must be wrapped in
        # EnsembleCoordinate, as the default guesser's _is_number does.
        return EnsembleCoordinate(c)
    return None
56 changes: 27 additions & 29 deletions src/anemoi/datasets/create/functions/sources/xarray/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,53 +10,49 @@
import logging
from functools import cached_property

from anemoi.utils.dates import as_datetime
from earthkit.data.core.geography import Geography
from earthkit.data.core.metadata import RawMetadata
from earthkit.data.utils.dates import to_datetime
from earthkit.data.utils.projections import Projection

LOG = logging.getLogger(__name__)


class MDMapping:
class _MDMapping:

def __init__(self, mapping):
self.user_to_internal = mapping
def __init__(self, variable):
self.variable = variable
self.time = variable.time
self.mapping = dict(param="variable")
for c in variable.coordinates:
for v in c.mars_names:
assert v not in self.mapping, f"Duplicate key '{v}' in {c}"
self.mapping[v] = c.variable.name

def from_user(self, kwargs):
if isinstance(kwargs, str):
return self.user_to_internal.get(kwargs, kwargs)
return {self.user_to_internal.get(k, k): v for k, v in kwargs.items()}
def _from_user(self, key):
return self.mapping.get(key, key)

def __len__(self):
return len(self.user_to_internal)
def from_user(self, kwargs):
    """Return a copy of ``kwargs`` with every key translated by ``_from_user``.

    Values are passed through untouched. The leftover debug ``print`` was
    removed so that this method no longer writes to stdout on every call.
    """
    return {self._from_user(key): value for key, value in kwargs.items()}

def __repr__(self):
return f"MDMapping({self.user_to_internal})"
return f"MDMapping({self.mapping})"

def fill_time_metadata(self, field, md):
md["valid_datetime"] = as_datetime(self.variable.time.fill_time_metadata(field._md, md)).isoformat()


class XArrayMetadata(RawMetadata):
LS_KEYS = ["variable", "level", "valid_datetime", "units"]
NAMESPACES = ["default", "mars"]
MARS_KEYS = ["param", "step", "levelist", "levtype", "number", "date", "time"]

def __init__(self, field, mapping):
def __init__(self, field):
self._field = field
md = field._md.copy()

self._mapping = mapping
if mapping is None:
time_coord = [c for c in field.owner.coordinates if c.is_time]
if len(time_coord) == 1:
time_key = time_coord[0].name
else:
time_key = "time"
else:
time_key = mapping.from_user("valid_datetime")
self._time = to_datetime(md.pop(time_key))
self._field.owner.time.fill_time_metadata(self._time, md)
md["valid_datetime"] = self._time.isoformat()

self._mapping = _MDMapping(field.owner)
self._mapping.fill_time_metadata(field, md)
super().__init__(md)

@cached_property
Expand Down Expand Up @@ -88,10 +84,13 @@ def _base_datetime(self):
return self._field.forecast_reference_time

def _valid_datetime(self):
return self._time
return self._get("valid_datetime")

def _get(self, key, **kwargs):

if key in self._d:
return self._d[key]

if key.startswith("mars."):
key = key[5:]
if key not in self.MARS_KEYS:
Expand All @@ -100,8 +99,7 @@ def _get(self, key, **kwargs):
else:
return kwargs.get("default", None)

if self._mapping is not None:
key = self._mapping.from_user(key)
key = self._mapping._from_user(key)

return super()._get(key, **kwargs)

Expand Down
93 changes: 63 additions & 30 deletions src/anemoi/datasets/create/functions/sources/xarray/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,28 @@

import datetime

from anemoi.utils.dates import as_datetime


class Time:

@classmethod
def from_coordinates(cls, coordinates):
time_coordinate = [c for c in coordinates if c.is_time]
step_coordinate = [c for c in coordinates if c.is_step]
date_coordinate = [c for c in coordinates if c.is_date]

if len(date_coordinate) == 0 and len(time_coordinate) == 1 and len(step_coordinate) == 1:
return ForecasstFromValidTimeAndStep(step_coordinate[0])
return ForecastFromValidTimeAndStep(time_coordinate[0], step_coordinate[0])

if len(date_coordinate) == 0 and len(time_coordinate) == 1 and len(step_coordinate) == 0:
return Analysis()
return Analysis(time_coordinate[0])

if len(date_coordinate) == 0 and len(time_coordinate) == 0 and len(step_coordinate) == 0:
return Constant()

if len(date_coordinate) == 1 and len(time_coordinate) == 1 and len(step_coordinate) == 0:
return ForecastFromValidTimeAndBaseTime(date_coordinate[0])
return ForecastFromValidTimeAndBaseTime(date_coordinate[0], time_coordinate[0])

if len(date_coordinate) == 1 and len(time_coordinate) == 0 and len(step_coordinate) == 1:
return ForecastFromBaseTimeAndDate(date_coordinate[0], step_coordinate[0])
Expand All @@ -38,61 +41,91 @@ def from_coordinates(cls, coordinates):

class Constant(Time):

def fill_time_metadata(self, time, metadata):
metadata["date"] = time.strftime("%Y%m%d")
metadata["time"] = time.strftime("%H%M")
metadata["step"] = 0
def fill_time_metadata(self, coords_values, metadata):
raise NotImplementedError("Constant time not implemented")
# print("Constant", coords_values, metadata)
# metadata["date"] = time.strftime("%Y%m%d")
# metadata["time"] = time.strftime("%H%M")
# metadata["step"] = 0


class Analysis(Time):

def fill_time_metadata(self, time, metadata):
metadata["date"] = time.strftime("%Y%m%d")
metadata["time"] = time.strftime("%H%M")
def __init__(self, time_coordinate):
self.time_coordinate_name = time_coordinate.variable.name

def fill_time_metadata(self, coords_values, metadata):
valid_datetime = coords_values[self.time_coordinate_name]

metadata["date"] = as_datetime(valid_datetime).strftime("%Y%m%d")
metadata["time"] = as_datetime(valid_datetime).strftime("%H%M")
metadata["step"] = 0

return valid_datetime

class ForecasstFromValidTimeAndStep(Time):
def __init__(self, step_coordinate):
self.step_name = step_coordinate.variable.name

def fill_time_metadata(self, time, metadata):
step = metadata.pop(self.step_name)
class ForecastFromValidTimeAndStep(Time):

def __init__(self, time_coordinate, step_coordinate):
self.time_coordinate_name = time_coordinate.variable.name
self.step_coordinate_name = step_coordinate.variable.name

def fill_time_metadata(self, coords_values, metadata):
valid_datetime = coords_values[self.time_coordinate_name]
step = coords_values[self.step_coordinate_name]

assert isinstance(step, datetime.timedelta)
base = time - step
base_datetime = valid_datetime - step

hours = step.total_seconds() / 3600
assert int(hours) == hours

metadata["date"] = base.strftime("%Y%m%d")
metadata["time"] = base.strftime("%H%M")
metadata["date"] = as_datetime(base_datetime).strftime("%Y%m%d")
metadata["time"] = as_datetime(base_datetime).strftime("%H%M")
metadata["step"] = int(hours)
return valid_datetime


class ForecastFromValidTimeAndBaseTime(Time):
def __init__(self, date_coordinate):
self.date_coordinate = date_coordinate

def fill_time_metadata(self, time, metadata):
def __init__(self, date_coordinate, time_coordinate):
    """Remember the names of the base-date and valid-time coordinates.

    Args:
        date_coordinate: coordinate object exposing ``.name`` (base date).
        time_coordinate: coordinate object exposing ``.name`` (valid time).
    """
    # fill_time_metadata reads these two attributes. The previous code assigned
    # to self.date_coordinate.name / self.time_coordinate.name, which raised
    # AttributeError because those attributes were never created.
    self.date_coordinate_name = date_coordinate.name
    self.time_coordinate_name = time_coordinate.name

def fill_time_metadata(self, coords_values, metadata):
valid_datetime = coords_values[self.time_coordinate_name]
base_datetime = coords_values[self.date_coordinate_name]

step = time - self.date_coordinate
step = valid_datetime - base_datetime

hours = step.total_seconds() / 3600
assert int(hours) == hours

metadata["date"] = self.date_coordinate.single_value.strftime("%Y%m%d")
metadata["time"] = self.date_coordinate.single_value.strftime("%H%M")
metadata["date"] = as_datetime(base_datetime).strftime("%Y%m%d")
metadata["time"] = as_datetime(base_datetime).strftime("%H%M")
metadata["step"] = int(hours)

return valid_datetime


class ForecastFromBaseTimeAndDate(Time):

def __init__(self, date_coordinate, step_coordinate):
self.date_coordinate = date_coordinate
self.step_coordinate = step_coordinate
self.date_coordinate_name = date_coordinate.name
self.step_coordinate_name = step_coordinate.name

def fill_time_metadata(self, coords_values, metadata):

date = coords_values[self.date_coordinate_name]
step = coords_values[self.step_coordinate_name]
assert isinstance(step, datetime.timedelta)

metadata["date"] = as_datetime(date).strftime("%Y%m%d")
metadata["time"] = as_datetime(date).strftime("%H%M")

hours = step.total_seconds() / 3600

def fill_time_metadata(self, time, metadata):
metadata["date"] = time.strftime("%Y%m%d")
metadata["time"] = time.strftime("%H%M")
hours = metadata[self.step_coordinate.name].total_seconds() / 3600
assert int(hours) == hours
metadata["step"] = int(hours)

return date + step
Loading

0 comments on commit ec22704

Please sign in to comment.