Skip to content

Commit

Permalink
Merge branch 'develop' into feature/augment
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult authored Dec 18, 2024
2 parents 019546f + 22ae74c commit bda0dda
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 26 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ Keep it human-readable, your future self will thank you!
- Fix cutout slicing of grid dimension (#145)
- Use cKDTree instead of KDTree
- Implement 'complement' feature
- Update accumulations (#158)

- Add ability to patch xarrays (#160)

### Added

Expand Down
12 changes: 8 additions & 4 deletions src/anemoi/datasets/create/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,14 @@ def check_shape(cube, dates, dates_in_data):

check_shape(cube, dates, dates_in_data)

def check_dates_in_data(lst, lst2):
lst2 = [np.datetime64(_) for _ in lst2]
lst = [np.datetime64(_) for _ in lst]
assert lst == lst2, ("Dates in data are not the requested ones:", lst, lst2)
def check_dates_in_data(dates_in_data, requested_dates):
    """Assert that the dates found in the data are the requested ones.

    Every element of both inputs is converted to ``numpy.datetime64`` and
    the two resulting lists are compared, so order and length matter too.

    Parameters
    ----------
    dates_in_data : iterable
        Dates found in the data; each element must be accepted by
        ``numpy.datetime64``.
    requested_dates : iterable
        Dates that were requested; same requirement.

    Raises
    ------
    AssertionError
        If the converted lists differ. The error payload is a tuple made
        of a message, the dates in the data and the requested dates.
    """
    # Requested dates are converted first, then the dates from the data.
    wanted = list(map(np.datetime64, requested_dates))
    found = list(map(np.datetime64, dates_in_data))

    report = ("Dates in data are not the requested ones:", found, wanted)
    assert found == wanted, report

check_dates_in_data(dates_in_data, dates)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def check(what, ds, paths, **kwargs):
raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})")


def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs):
def load_one(emoji, context, dates, dataset, *, options={}, flavour=None, patch=None, **kwargs):
import xarray as xr

"""
Expand All @@ -54,10 +54,10 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs)
else:
data = xr.open_dataset(dataset, **options)

fs = XarrayFieldList.from_xarray(data, flavour)
fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch)

if len(dates) == 0:
return fs.sel(**kwargs)
result = fs.sel(**kwargs)
else:
result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates])

Expand Down
11 changes: 10 additions & 1 deletion src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from .field import EmptyFieldList
from .flavour import CoordinateGuesser
from .patch import patch_dataset
from .time import Time
from .variable import FilteredVariable
from .variable import Variable
Expand Down Expand Up @@ -49,7 +50,11 @@ def __getitem__(self, i):
raise IndexError(k)

@classmethod
def from_xarray(cls, ds, flavour=None):
def from_xarray(cls, ds, *, flavour=None, patch=None):

if patch is not None:
ds = patch_dataset(ds, patch)

variables = []

if isinstance(flavour, str):
Expand Down Expand Up @@ -83,6 +88,8 @@ def _skip_attr(v, attr_name):
_skip_attr(variable, "bounds")
_skip_attr(variable, "grid_mapping")

LOG.debug("Xarray data_vars: %s", ds.data_vars)

# Select only geographical variables
for name in ds.data_vars:

Expand All @@ -97,13 +104,15 @@ def _skip_attr(v, attr_name):
c = guess.guess(ds[coord], coord)
assert c, f"Could not guess coordinate for {coord}"
if coord not in variable.dims:
LOG.debug("%s: coord=%s (not a dimension): dims=%s", variable, coord, variable.dims)
c.is_dim = False
coordinates.append(c)

grid_coords = sum(1 for c in coordinates if c.is_grid and c.is_dim)
assert grid_coords <= 2

if grid_coords < 2:
LOG.debug("Skipping %s (not 2D): %s", variable, [(c, c.is_grid, c.is_dim) for c in coordinates])
continue

v = Variable(
Expand Down
16 changes: 5 additions & 11 deletions src/anemoi/datasets/create/functions/sources/xarray/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class _MDMapping:
def __init__(self, variable):
self.variable = variable
self.time = variable.time
# Aliases
self.mapping = dict(param="variable")
for c in variable.coordinates:
for v in c.mars_names:
Expand All @@ -34,7 +35,6 @@ def _from_user(self, key):
return self.mapping.get(key, key)

def from_user(self, kwargs):
print("from_user", kwargs, self)
return {self._from_user(k): v for k, v in kwargs.items()}

def __repr__(self):
Expand Down Expand Up @@ -81,22 +81,16 @@ def _base_datetime(self):
def _valid_datetime(self):
return self._get("valid_datetime")

def _get(self, key, **kwargs):
def get(self, key, astype=None, **kwargs):

if key in self._d:
if astype is not None:
return astype(self._d[key])
return self._d[key]

if key.startswith("mars."):
key = key[5:]
if key not in self.MARS_KEYS:
if kwargs.get("raise_on_missing", False):
raise KeyError(f"Invalid key '{key}' in namespace='mars'")
else:
return kwargs.get("default", None)

key = self._mapping._from_user(key)

return super()._get(key, **kwargs)
return super().get(key, astype=astype, **kwargs)


class XArrayFieldGeography(Geography):
Expand Down
44 changes: 44 additions & 0 deletions src/anemoi/datasets/create/functions/sources/xarray/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging

LOG = logging.getLogger(__name__)


def patch_attributes(ds, attributes):
    """Add or override attributes on variables of a dataset.

    Parameters
    ----------
    ds : mapping-like
        Dataset indexed by variable name; ``ds[name]`` must expose an
        ``attrs`` mapping supporting ``update``.
    attributes : dict
        Maps a variable name to a dict of attributes to merge into that
        variable's ``attrs``.

    Returns
    -------
    The ``ds`` object that was passed in.

    Raises
    ------
    Whatever ``ds[name]`` raises when ``name`` is not in the dataset.
    """
    for target, extra_attrs in attributes.items():
        # update() merges: existing keys are overwritten, others are kept.
        ds[target].attrs.update(extra_attrs)

    return ds


def patch_coordinates(ds, coordinates):
    """Re-assign the named variables of a dataset as coordinates.

    For each name, in order, the dataset is replaced by the result of
    ``assign_coords({name: ds[name]})``.

    Parameters
    ----------
    ds : object
        Dataset supporting ``ds[name]`` and ``assign_coords(mapping)``.
    coordinates : iterable of str
        Names of the variables to pass to ``assign_coords``.

    Returns
    -------
    The dataset returned by the last ``assign_coords`` call, or ``ds``
    itself when ``coordinates`` is empty.
    """
    patched = ds
    for coord_name in coordinates:
        promoted = patched[coord_name]
        patched = patched.assign_coords({coord_name: promoted})

    return patched


# Dispatch table used by patch_dataset(): maps each supported key of the
# `patch` mapping to the function that applies that kind of patch.
# Each function is called as func(ds, values) and returns the dataset.
PATCHES = {
"attributes": patch_attributes,
"coordinates": patch_coordinates,
}


def patch_dataset(ds, patch):
    """Apply a set of patches to a dataset.

    Parameters
    ----------
    ds : object
        The dataset to patch; it is handed to the functions registered in
        ``PATCHES`` (presumably an ``xarray.Dataset`` -- confirm against
        callers).
    patch : dict
        Maps a patch type (a key of ``PATCHES``) to the values passed to
        the matching patch function. Patches are applied in the iteration
        order of this mapping.

    Returns
    -------
    The dataset returned by the last patch function applied, or ``ds``
    itself when ``patch`` is empty.

    Raises
    ------
    ValueError
        If ``patch`` contains a key that is not in ``PATCHES``. Nothing is
        applied in that case.
    """
    # Validate every key before applying anything: patch_attributes updates
    # attributes in place, so failing half-way through would hand the caller
    # back a partially patched dataset.
    for what in patch:
        if what not in PATCHES:
            raise ValueError(f"Unknown patch type {what!r}")

    for what, values in patch.items():
        ds = PATCHES[what](ds, values)

    return ds
15 changes: 15 additions & 0 deletions src/anemoi/datasets/create/functions/sources/xarray/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,18 @@ def from_coordinates(cls, coordinates):

raise NotImplementedError(f"{len(date_coordinate)=} {len(time_coordinate)=} {len(step_coordinate)=}")

def select_valid_datetime(self, variable):
    """Return the name of the coordinate used for ``valid_datetime`` selection.

    This base implementation always raises; the subclasses visible in this
    module provide their own version.

    Raises
    ------
    NotImplementedError
        Always; the message names the concrete class.
    """
    owner = self.__class__.__name__
    raise NotImplementedError(f"{owner}.select_valid_datetime()")


class Constant(Time):
    """Time handler whose methods provide no time information."""

    def fill_time_metadata(self, coords_values, metadata):
        """Leave ``metadata`` untouched and report no valid datetime.

        Returns
        -------
        None
        """
        return None

    def select_valid_datetime(self, variable):
        """Report that there is no coordinate to select ``valid_datetime`` on.

        Returns
        -------
        None
            ``Variable.sel`` returns ``None`` for a ``valid_datetime``
            selection when it receives this value.
        """
        return None


class Analysis(Time):

Expand All @@ -83,6 +89,9 @@ def fill_time_metadata(self, coords_values, metadata):

return valid_datetime

def select_valid_datetime(self, variable):
    """Return the name of the coordinate to select ``valid_datetime`` on.

    ``Variable.sel`` uses the returned name to look up the coordinate.
    The ``variable`` argument is not used by this implementation.

    Returns
    -------
    The value of ``self.time_coordinate_name``.
    """
    return self.time_coordinate_name


class ForecastFromValidTimeAndStep(Time):

Expand Down Expand Up @@ -116,6 +125,9 @@ def fill_time_metadata(self, coords_values, metadata):

return valid_datetime

def select_valid_datetime(self, variable):
    """Return the name of the coordinate to select ``valid_datetime`` on.

    ``Variable.sel`` uses the returned name to look up the coordinate.
    The ``variable`` argument is not used by this implementation.

    Returns
    -------
    The value of ``self.time_coordinate_name``.
    """
    return self.time_coordinate_name


class ForecastFromValidTimeAndBaseTime(Time):

Expand All @@ -138,6 +150,9 @@ def fill_time_metadata(self, coords_values, metadata):

return valid_datetime

def select_valid_datetime(self, variable):
    """Return the name of the coordinate to select ``valid_datetime`` on.

    ``Variable.sel`` uses the returned name to look up the coordinate.
    The ``variable`` argument is not used by this implementation.

    Returns
    -------
    The value of ``self.time_coordinate_name``.
    """
    return self.time_coordinate_name


class ForecastFromBaseTimeAndDate(Time):

Expand Down
20 changes: 18 additions & 2 deletions src/anemoi/datasets/create/functions/sources/xarray/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,17 @@ def __init__(
self.coordinates = coordinates

self._metadata = metadata.copy()
self._metadata.update({"variable": variable.name})
self._metadata.update({"variable": variable.name, "param": variable.name})

self.time = time

self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid)
self.names = {c.variable.name: c for c in coordinates if c.is_dim and not c.scalar and not c.is_grid}
self.by_name = {c.variable.name: c for c in coordinates}

# We need that alias for the time dimension
self._aliases = dict(valid_datetime="time")

self.length = math.prod(self.shape)

@property
Expand Down Expand Up @@ -96,15 +99,28 @@ def sel(self, missing, **kwargs):

k, v = kwargs.popitem()

user_provided_k = k

if k == "valid_datetime":
# Ask the Time object to select the valid datetime
k = self.time.select_valid_datetime(self)
if k is None:
return None

c = self.by_name.get(k)

# assert c is not None, f"Could not find coordinate {k} in {self.variable.name} {self.coordinates} {list(self.by_name)}"

if c is None:
missing[k] = v
return self.sel(missing, **kwargs)

i = c.index(v)
if i is None:
LOG.warning(f"Could not find {k}={v} in {c}")
if k != user_provided_k:
LOG.warning(f"Could not find {user_provided_k}={v} in {c} (alias of {k})")
else:
LOG.warning(f"Could not find {k}={v} in {c}")
return None

coordinates = [x.reduced(i) if c is x else x for x in self.coordinates]
Expand Down
4 changes: 4 additions & 0 deletions src/anemoi/datasets/create/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def to_datetime(*args, **kwargs):


def make_list_int(value):
# Convert a string like "1/2/3" or "1/to/3" or "1/to/10/by/2" to a list of integers.
# TODO: this function has moved to anemoi.utils.humanize. Replace it with
# `from anemoi.utils.humanize import make_list_int` once anemoi-utils is
# released and pyproject.toml is updated.
if isinstance(value, str):
if "/" not in value:
return [value]
Expand Down
47 changes: 47 additions & 0 deletions src/anemoi/datasets/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,50 @@ def _compute_constant_fields_from_statistics(self):
result.append(v)

return result

def plot(self, date, variable, member=0, **kwargs):
    """For debugging purposes, plot a field.

    Parameters
    ----------
    date : int or datetime.datetime or numpy.datetime64 or str
        The date to plot. An integer (Python ``int`` or NumPy integer
        scalar) is used directly as an index into the dataset; any other
        value is converted to a date and looked up in ``self.dates``.
    variable : int or str
        The variable to plot. An integer (Python ``int`` or NumPy integer
        scalar) is used directly as an index; otherwise it is a variable
        name resolved through ``self.name_to_index``.
    member : int, optional
        The ensemble member to plot.
    **kwargs:
        Additional arguments to pass to matplotlib.pyplot.tricontourf

    Returns
    -------
    matplotlib.pyplot.Axes

    Raises
    ------
    ValueError
        If ``date`` is not an integer and is not found in ``self.dates``,
        or if ``variable`` is not an integer and is not in
        ``self.variables``.
    """

    from anemoi.utils.devtools import plot_values
    from earthkit.data.utils.dates import to_datetime

    # NumPy integer scalars (e.g. an element of np.where(...)[0]) are not
    # instances of `int`, so they must be accepted explicitly as indices.
    integer_types = (int, np.integer)

    if isinstance(date, integer_types):
        date_index = date
    else:
        # Cast to the dtype of the stored dates so that `==` compares like with like.
        date = np.datetime64(to_datetime(date)).astype(self.dates[0].dtype)
        index = np.where(self.dates == date)[0]
        if len(index) == 0:
            raise ValueError(
                f"Date {date} not found in the dataset {self.dates[0]} to {self.dates[-1]} by {self.frequency}"
            )
        date_index = index[0]

    if isinstance(variable, integer_types):
        variable_index = variable
    else:
        if variable not in self.variables:
            raise ValueError(f"Unknown variable {variable} (available: {self.variables})")

        variable_index = self.name_to_index[variable]

    values = self[date_index, variable_index, member]

    return plot_values(values, self.latitudes, self.longitudes, **kwargs)
6 changes: 3 additions & 3 deletions tests/xarray/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_weatherbench():
"levtype": "pl",
}

fs = XarrayFieldList.from_xarray(ds, flavour)
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)

assert_field_list(
fs,
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_noaa_replay():
"levtype": "pl",
}

fs = XarrayFieldList.from_xarray(ds, flavour)
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)

assert_field_list(
fs,
Expand All @@ -141,7 +141,7 @@ def test_planetary_computer_conus404():
},
}

fs = XarrayFieldList.from_xarray(ds, flavour)
fs = XarrayFieldList.from_xarray(ds, flavour=flavour)

assert_field_list(
fs,
Expand Down

0 comments on commit bda0dda

Please sign in to comment.