diff --git a/CHANGELOG.md b/CHANGELOG.md index 16c19d589..56a0ea89d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,13 +15,14 @@ Keep it human-readable, your future self will thank you! - Fix metadata serialization handling of numpy.integer (#140) - Fix negative variance for constant variables (#148) - Fix cutout slicing of grid dimension (#145) -- update acumulation (#158) +- Use cKDTree instead of KDTree +- Implement 'complement' feature - Add ability to patch xarrays (#160) ### Added - Call filters from anemoi-transform -- make test optional when adls is not installed Pull request #110 +- Make test optional when adls is not installed Pull request #110 - Add wz_to_w, orog_to_z, and sum filters (#149) ## [0.5.8](https://github.com/ecmwf/anemoi-datasets/compare/0.5.7...0.5.8) - 2024-10-26 diff --git a/docs/using/code/complement1_.py b/docs/using/code/complement1_.py new file mode 100644 index 000000000..ad77234d4 --- /dev/null +++ b/docs/using/code/complement1_.py @@ -0,0 +1,6 @@ +open_dataset( + complement=dataset1, + source=dataset2, + what="variables", + interpolate="nearest", +) diff --git a/docs/using/code/complement2_.py b/docs/using/code/complement2_.py new file mode 100644 index 000000000..22ad78044 --- /dev/null +++ b/docs/using/code/complement2_.py @@ -0,0 +1,12 @@ +open_dataset( + cutout=[ + { + "complement": lam_dataset, + "source": global_dataset, + "interpolate": "nearest", + }, + { + "dataset": global_dataset, + }, + ] +) diff --git a/docs/using/code/complement3_.py b/docs/using/code/complement3_.py new file mode 100644 index 000000000..e20b30a39 --- /dev/null +++ b/docs/using/code/complement3_.py @@ -0,0 +1,4 @@ +open_dataset( + complement=dataset1, + source=dataset2, +) diff --git a/docs/using/combining.rst b/docs/using/combining.rst index 86f1082fa..1d6acd1d2 100644 --- a/docs/using/combining.rst +++ b/docs/using/combining.rst @@ -182,3 +182,32 @@ The difference can be seen at the boundary between the two grids: To debug the combination, you can pass `plot=True` to the `cutout` function (when running from a Notebook), of use `plot="prefix"` to save the plots to series of PNG files in the current directory. + +.. _complement: + +************ + complement +************ + +That feature will interpolate the variables of `dataset2` that are not +in `dataset1` to the grid of `dataset1` , add them to the list of +variable of `dataset1` and return the result. + +.. literalinclude:: code/complement1_.py + +Currently ``what`` can only be ``variables`` and can be omitted. + +The value for ``interpolate`` can be one of ``none`` (default) or +``nearest``. In the case of ``none``, the grids of the two datasets must +match. + +This feature was originally designed to be used in conjunction with +``cutout``, where `dataset1` is the lam, and `dataset2` is the global +dataset. + +.. literalinclude:: code/complement2_.py + +Another use case is to simply bring all non-overlapping variables of a +dataset into an other: + +.. literalinclude:: code/complement3_.py diff --git a/src/anemoi/datasets/create/functions/sources/xarray/field.py b/src/anemoi/datasets/create/functions/sources/xarray/field.py index 257e2932b..3f4c2a5e7 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/field.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/field.py @@ -120,6 +120,6 @@ def forecast_reference_time(self): def __repr__(self): return repr(self._metadata) - def _values(self): + def _values(self, dtype=None): # we don't use .values as this will download the data return self.selection diff --git a/src/anemoi/datasets/data/complement.py b/src/anemoi/datasets/data/complement.py new file mode 100644 index 000000000..ee324dd07 --- /dev/null +++ b/src/anemoi/datasets/data/complement.py @@ -0,0 +1,164 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import logging +from functools import cached_property + +from ..grids import nearest_grid_points +from .debug import Node +from .forwards import Combined +from .indexing import apply_index_to_slices_changes +from .indexing import index_to_slices +from .indexing import update_tuple +from .misc import _auto_adjust +from .misc import _open + +LOG = logging.getLogger(__name__) + + +class Complement(Combined): + + def __init__(self, target, source, what="variables", interpolation="nearest"): + super().__init__([target, source]) + + # We had the variables of dataset[1] to dataset[0] + # interpoated on the grid of dataset[0] + + self.target = target + self.source = source + + self._variables = [] + + # Keep the same order as the original dataset + for v in self.source.variables: + if v not in self.target.variables: + self._variables.append(v) + + if not self._variables: + raise ValueError("Augment: no missing variables") + + @property + def variables(self): + return self._variables + + @property + def name_to_index(self): + return {v: i for i, v in enumerate(self.variables)} + + @property + def shape(self): + shape = self.target.shape + return (shape[0], len(self._variables)) + shape[2:] + + @property + def variables_metadata(self): + return {k: v for k, v in self.source.variables_metadata.items() if k in self._variables} + + def check_same_variables(self, d1, d2): + pass + + @cached_property + def missing(self): + missing = self.source.missing.copy() + missing = missing | self.target.missing + return set(missing) + + def tree(self): + """Generates a hierarchical tree structure for the `Cutout` instance and + its associated datasets. + + Returns: + Node: A `Node` object representing the `Cutout` instance as the root + node, with each dataset in `self.datasets` represented as a child + node. + """ + return Node(self, [d.tree() for d in (self.target, self.source)]) + + def __getitem__(self, index): + if isinstance(index, (int, slice)): + index = (index, slice(None), slice(None), slice(None)) + return self._get_tuple(index) + + +class ComplementNone(Complement): + + def __init__(self, target, source): + super().__init__(target, source) + + def _get_tuple(self, index): + index, changes = index_to_slices(index, self.shape) + result = self.source[index] + return apply_index_to_slices_changes(result, changes) + + +class ComplementNearest(Complement): + + def __init__(self, target, source): + super().__init__(target, source) + + self._nearest_grid_points = nearest_grid_points( + self.source.latitudes, + self.source.longitudes, + self.target.latitudes, + self.target.longitudes, + ) + + def check_compatibility(self, d1, d2): + pass + + def _get_tuple(self, index): + variable_index = 1 + index, changes = index_to_slices(index, self.shape) + index, previous = update_tuple(index, variable_index, slice(None)) + source_index = [self.source.name_to_index[x] for x in self.variables[previous]] + source_data = self.source[index[0], source_index, index[2], ...] + target_data = source_data[..., self._nearest_grid_points] + + result = target_data[..., index[3]] + + return apply_index_to_slices_changes(result, changes) + + +def complement_factory(args, kwargs): + from .select import Select + + assert len(args) == 0, args + + source = kwargs.pop("source") + target = kwargs.pop("complement") + what = kwargs.pop("what", "variables") + interpolation = kwargs.pop("interpolation", "none") + + if what != "variables": + raise NotImplementedError(f"Complement what={what} not implemented") + + if interpolation not in ("none", "nearest"): + raise NotImplementedError(f"Complement method={interpolation} not implemented") + + source = _open(source) + target = _open(target) + # `select` is the same as `variables` + (source, target), kwargs = _auto_adjust([source, target], kwargs, exclude=["select"]) + + Class = { + None: ComplementNone, + "none": ComplementNone, + "nearest": ComplementNearest, + }[interpolation] + + complement = Class(target=target, source=source)._subset(**kwargs) + + # Will join the datasets along the variables axis + reorder = source.variables + complemented = _open([target, complement]) + ordered = ( + Select(complemented, complemented._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate() + ) + return ordered diff --git a/src/anemoi/datasets/data/join.py b/src/anemoi/datasets/data/join.py index 6b7de3e64..30d1f1379 100644 --- a/src/anemoi/datasets/data/join.py +++ b/src/anemoi/datasets/data/join.py @@ -118,6 +118,7 @@ def variables(self): def variables_metadata(self): result = {} variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))] + for d in self.datasets: md = d.variables_metadata for v in variables: diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index aad751f07..2d2493a3a 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -194,7 +194,7 @@ def _open(a): raise NotImplementedError(f"Unsupported argument: {type(a)}") -def _auto_adjust(datasets, kwargs): +def _auto_adjust(datasets, kwargs, exclude=None): if "adjust" not in kwargs: return datasets, kwargs @@ -214,6 +214,9 @@ def _auto_adjust(datasets, kwargs): for a in adjust_list: adjust_set.update(ALIASES.get(a, [a])) + if exclude is not None: + adjust_set -= set(exclude) + extra = set(adjust_set) - set(ALIASES["all"]) if extra: raise ValueError(f"Invalid adjust keys: {extra}") @@ -335,6 +338,12 @@ def _open_dataset(*args, **kwargs): assert not sets, sets return cutout_factory(args, kwargs).mutate() + if "complement" in kwargs: + from .complement import complement_factory + + assert not sets, sets + return complement_factory(args, kwargs).mutate() + for name in ("datasets", "dataset"): if name in kwargs: datasets = kwargs.pop(name) diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index 9e3a10b6d..e5d58f0ce 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -152,7 +152,7 @@ def cutout_mask( plot=None, ): """Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]""" - from scipy.spatial import KDTree + from scipy.spatial import cKDTree # TODO: transform min_distance from lat/lon to xyz @@ -195,13 +195,13 @@ def cutout_mask( min_distance = min_distance_km / 6371.0 else: points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km] - distances, _ = KDTree(points).query(points, k=2) + distances, _ = cKDTree(points).query(points, k=2) min_distance = np.min(distances[:, 1]) LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km") - # Use a KDTree to find the nearest points - distances, indices = KDTree(lam_points).query(global_points, k=neighbours) + # Use a cKDTree to find the nearest points + distances, indices = cKDTree(lam_points).query(global_points, k=neighbours) # Centre of the Earth zero = np.array([0.0, 0.0, 0.0]) @@ -255,7 +255,7 @@ def thinning_mask( cropping_distance=2.0, ): """Return the list of points in [lats, lons] closest to [global_lats, global_lons]""" - from scipy.spatial import KDTree + from scipy.spatial import cKDTree assert global_lats.ndim == 1 assert global_lons.ndim == 1 @@ -291,20 +291,20 @@ def thinning_mask( xyx = latlon_to_xyz(lats, lons) points = np.array(xyx).transpose() - # Use a KDTree to find the nearest points - _, indices = KDTree(points).query(global_points, k=1) + # Use a cKDTree to find the nearest points + _, indices = cKDTree(points).query(global_points, k=1) return np.array([i for i in indices]) def outline(lats, lons, neighbours=5): - from scipy.spatial import KDTree + from scipy.spatial import cKDTree xyx = latlon_to_xyz(lats, lons) grid_points = np.array(xyx).transpose() - # Use a KDTree to find the nearest points - _, indices = KDTree(grid_points).query(grid_points, k=neighbours) + # Use a cKDTree to find the nearest points + _, indices = cKDTree(grid_points).query(grid_points, k=neighbours) # Centre of the Earth zero = np.array([0.0, 0.0, 0.0]) @@ -379,6 +379,19 @@ def serialise_mask(mask): return result +def nearest_grid_points(source_latitudes, source_longitudes, target_latitudes, target_longitudes): + from scipy.spatial import cKDTree + + source_xyz = latlon_to_xyz(source_latitudes, source_longitudes) + source_points = np.array(source_xyz).transpose() + + target_xyz = latlon_to_xyz(target_latitudes, target_longitudes) + target_points = np.array(target_xyz).transpose() + + _, indices = cKDTree(source_points).query(target_points, k=1) + return indices + + if __name__ == "__main__": global_lats, global_lons = np.meshgrid( np.linspace(90, -90, 90),