Skip to content

Commit

Permalink
Complement a dataset from an other, using interpolation to nearest ne…
Browse files Browse the repository at this point in the history
…ighbourg (#152)

* add augment feature

* fix shape

* Make missing date object lam compatible

* Fix bug related to changes in earthkit-data

* Add ordering of variables

* Revert to uncommented + format

* Remove unlinked code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove print messages

* update documentation

* Fix variables match with source data

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: omiralles <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Florian Pinault <[email protected]>
  • Loading branch information
4 people authored Dec 18, 2024
1 parent 22ae74c commit 8fd1000
Show file tree
Hide file tree
Showing 10 changed files with 253 additions and 14 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ Keep it human-readable, your future self will thank you!
- Fix metadata serialization handling of numpy.integer (#140)
- Fix negative variance for constant variables (#148)
- Fix cutout slicing of grid dimension (#145)
- Update accumulation (#158)
- Use cKDTree instead of KDTree
- Implement 'complement' feature
- Add ability to patch xarrays (#160)

### Added

- Call filters from anemoi-transform
- make test optional when adls is not installed Pull request #110
- Make test optional when adls is not installed Pull request #110
- Add wz_to_w, orog_to_z, and sum filters (#149)

## [0.5.8](https://github.com/ecmwf/anemoi-datasets/compare/0.5.7...0.5.8) - 2024-10-26
Expand Down
6 changes: 6 additions & 0 deletions docs/using/code/complement1_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Complement `dataset1` with the variables of `dataset2` that it lacks,
# interpolated onto the grid of `dataset1` using nearest neighbour.
open_dataset(
    complement=dataset1,
    source=dataset2,
    what="variables",
    interpolation="nearest",
)
12 changes: 12 additions & 0 deletions docs/using/code/complement2_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Typical use with ``cutout``: the lam dataset is complemented with the
# global dataset's extra variables before the two are combined.
open_dataset(
    cutout=[
        {
            "complement": lam_dataset,
            "source": global_dataset,
            "interpolation": "nearest",
        },
        {
            "dataset": global_dataset,
        },
    ]
)
4 changes: 4 additions & 0 deletions docs/using/code/complement3_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Bring all non-overlapping variables of `dataset2` into `dataset1`.
# No interpolation is performed, so the grids of the two datasets must match.
open_dataset(
    complement=dataset1,
    source=dataset2,
)
29 changes: 29 additions & 0 deletions docs/using/combining.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,32 @@ The difference can be seen at the boundary between the two grids:
To debug the combination, you can pass `plot=True` to the `cutout`
function (when running from a Notebook), or use `plot="prefix"` to save
the plots to a series of PNG files in the current directory.

.. _complement:

************
complement
************

This feature interpolates the variables of `dataset2` that are not in
`dataset1` onto the grid of `dataset1`, adds them to the list of
variables of `dataset1`, and returns the result.

.. literalinclude:: code/complement1_.py

Currently ``what`` can only be ``variables`` and can be omitted.

The value for ``interpolation`` can be one of ``none`` (default) or
``nearest``. In the case of ``none``, the grids of the two datasets must
match.

This feature was originally designed to be used in conjunction with
``cutout``, where `dataset1` is the lam, and `dataset2` is the global
dataset.

.. literalinclude:: code/complement2_.py

Another use case is to simply bring all non-overlapping variables of one
dataset into another:

.. literalinclude:: code/complement3_.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,6 @@ def forecast_reference_time(self):
def __repr__(self):
return repr(self._metadata)

def _values(self):
def _values(self, dtype=None):
    """Return the underlying selection without materializing the data.

    ``dtype`` is accepted for interface compatibility and is currently
    ignored — TODO confirm no caller relies on a dtype conversion here.
    """
    # we don't use .values as this will download the data
    return self.selection
164 changes: 164 additions & 0 deletions src/anemoi/datasets/data/complement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import logging
from functools import cached_property

from ..grids import nearest_grid_points
from .debug import Node
from .forwards import Combined
from .indexing import apply_index_to_slices_changes
from .indexing import index_to_slices
from .indexing import update_tuple
from .misc import _auto_adjust
from .misc import _open

LOG = logging.getLogger(__name__)


class Complement(Combined):
    """Complement `target` with the variables of `source` that `target` lacks.

    The resulting dataset exposes only the missing variables; subclasses
    decide how (and whether) to interpolate them onto the grid of `target`.
    """

    def __init__(self, target, source, what="variables", interpolation="nearest"):
        super().__init__([target, source])

        # We add the variables of dataset[1] (source) to dataset[0] (target),
        # interpolated on the grid of dataset[0].

        self.target = target
        self.source = source

        # Keep the same order as in the source dataset.
        self._variables = [v for v in self.source.variables if v not in self.target.variables]

        if not self._variables:
            raise ValueError("Complement: no missing variables")

    @property
    def variables(self):
        # Only the variables that `target` is missing.
        return self._variables

    @property
    def name_to_index(self):
        return {v: i for i, v in enumerate(self.variables)}

    @property
    def shape(self):
        # Same as the target's shape, except for the variables dimension.
        shape = self.target.shape
        return (shape[0], len(self._variables)) + shape[2:]

    @property
    def variables_metadata(self):
        return {k: v for k, v in self.source.variables_metadata.items() if k in self._variables}

    def check_same_variables(self, d1, d2):
        # The two datasets intentionally have different variables.
        pass

    @cached_property
    def missing(self):
        # A date is missing if it is missing in either dataset.
        return set(self.source.missing | self.target.missing)

    def tree(self):
        """Generates a hierarchical tree structure for the `Complement`
        instance and its associated datasets.

        Returns:
            Node: A `Node` object representing the `Complement` instance as
            the root node, with the target and source datasets as child
            nodes.
        """
        return Node(self, [d.tree() for d in (self.target, self.source)])

    def __getitem__(self, index):
        # Normalise scalar/slice indexing to a full tuple (dates, variables,
        # ensembles, grid) before delegating to the subclass.
        if isinstance(index, (int, slice)):
            index = (index, slice(None), slice(None), slice(None))
        return self._get_tuple(index)


class ComplementNone(Complement):
    """Complement without interpolation: the two grids must already match."""

    def __init__(self, target, source):
        super().__init__(target, source)

    def _get_tuple(self, index):
        full_index, changes = index_to_slices(index, self.shape)
        # NOTE(review): the variables axis of `full_index` refers to this
        # dataset's (missing) variables, while `self.source` holds all of
        # its own variables — presumably the caller guarantees compatible
        # ordering here; compare with ComplementNearest, which remaps.
        data = self.source[full_index]
        return apply_index_to_slices_changes(data, changes)


class ComplementNearest(Complement):
    """Complement using nearest-neighbour interpolation onto the target grid."""

    def __init__(self, target, source):
        super().__init__(target, source)

        # For every target grid point, the index of the closest source point.
        self._nearest_grid_points = nearest_grid_points(
            self.source.latitudes,
            self.source.longitudes,
            self.target.latitudes,
            self.target.longitudes,
        )

    def check_compatibility(self, d1, d2):
        # Different grids are expected here, so skip the usual check.
        pass

    def _get_tuple(self, index):
        VARIABLES_AXIS = 1
        slices, changes = index_to_slices(index, self.shape)
        # Select all variables first, remembering the requested selection.
        slices, requested = update_tuple(slices, VARIABLES_AXIS, slice(None))
        # Map our variable names back to their positions in the source.
        columns = [self.source.name_to_index[name] for name in self.variables[requested]]
        source_data = self.source[slices[0], columns, slices[2], ...]
        # Interpolate onto the target grid, then apply the grid selection.
        interpolated = source_data[..., self._nearest_grid_points]
        result = interpolated[..., slices[3]]

        return apply_index_to_slices_changes(result, changes)


def complement_factory(args, kwargs):
    """Build a dataset where `complement` is augmented with the variables of
    `source` that it lacks, interpolated onto the grid of `complement`.

    Recognised kwargs:
        complement: the target dataset (e.g. a lam).
        source: the dataset providing the extra variables.
        what: currently only "variables" (default).
        interpolation (or its documented alias ``interpolate``): "none"
            (default, grids must match) or "nearest".
    """
    from .select import Select

    assert len(args) == 0, args

    source = kwargs.pop("source")
    target = kwargs.pop("complement")
    what = kwargs.pop("what", "variables")
    # The documentation uses ``interpolate``; accept it as an alias of
    # ``interpolation`` so both spellings work.
    interpolation = kwargs.pop("interpolation", kwargs.pop("interpolate", "none"))

    if what != "variables":
        raise NotImplementedError(f"Complement what={what} not implemented")

    if interpolation not in ("none", "nearest"):
        raise NotImplementedError(f"Complement method={interpolation} not implemented")

    source = _open(source)
    target = _open(target)
    # `select` is the same as `variables`
    (source, target), kwargs = _auto_adjust([source, target], kwargs, exclude=["select"])

    Class = {
        None: ComplementNone,
        "none": ComplementNone,
        "nearest": ComplementNearest,
    }[interpolation]

    complement = Class(target=target, source=source)._subset(**kwargs)

    # Join the datasets along the variables axis, then restore the source's
    # variable ordering.
    reorder = source.variables
    complemented = _open([target, complement])
    # NOTE(review): "reoder" looks like a typo for "reorder" — kept as-is
    # because it is a runtime key passed to Select; confirm before renaming.
    ordered = (
        Select(complemented, complemented._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
    )
    return ordered
1 change: 1 addition & 0 deletions src/anemoi/datasets/data/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def variables(self):
def variables_metadata(self):
result = {}
variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))]

for d in self.datasets:
md = d.variables_metadata
for v in variables:
Expand Down
11 changes: 10 additions & 1 deletion src/anemoi/datasets/data/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def _open(a):
raise NotImplementedError(f"Unsupported argument: {type(a)}")


def _auto_adjust(datasets, kwargs):
def _auto_adjust(datasets, kwargs, exclude=None):

if "adjust" not in kwargs:
return datasets, kwargs
Expand All @@ -214,6 +214,9 @@ def _auto_adjust(datasets, kwargs):
for a in adjust_list:
adjust_set.update(ALIASES.get(a, [a]))

if exclude is not None:
adjust_set -= set(exclude)

extra = set(adjust_set) - set(ALIASES["all"])
if extra:
raise ValueError(f"Invalid adjust keys: {extra}")
Expand Down Expand Up @@ -335,6 +338,12 @@ def _open_dataset(*args, **kwargs):
assert not sets, sets
return cutout_factory(args, kwargs).mutate()

if "complement" in kwargs:
from .complement import complement_factory

assert not sets, sets
return complement_factory(args, kwargs).mutate()

for name in ("datasets", "dataset"):
if name in kwargs:
datasets = kwargs.pop(name)
Expand Down
33 changes: 23 additions & 10 deletions src/anemoi/datasets/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def cutout_mask(
plot=None,
):
"""Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]"""
from scipy.spatial import KDTree
from scipy.spatial import cKDTree

# TODO: transform min_distance from lat/lon to xyz

Expand Down Expand Up @@ -195,13 +195,13 @@ def cutout_mask(
min_distance = min_distance_km / 6371.0
else:
points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km]
distances, _ = KDTree(points).query(points, k=2)
distances, _ = cKDTree(points).query(points, k=2)
min_distance = np.min(distances[:, 1])

LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km")

# Use a KDTree to find the nearest points
distances, indices = KDTree(lam_points).query(global_points, k=neighbours)
# Use a cKDTree to find the nearest points
distances, indices = cKDTree(lam_points).query(global_points, k=neighbours)

# Centre of the Earth
zero = np.array([0.0, 0.0, 0.0])
Expand Down Expand Up @@ -255,7 +255,7 @@ def thinning_mask(
cropping_distance=2.0,
):
"""Return the list of points in [lats, lons] closest to [global_lats, global_lons]"""
from scipy.spatial import KDTree
from scipy.spatial import cKDTree

assert global_lats.ndim == 1
assert global_lons.ndim == 1
Expand Down Expand Up @@ -291,20 +291,20 @@ def thinning_mask(
xyx = latlon_to_xyz(lats, lons)
points = np.array(xyx).transpose()

# Use a KDTree to find the nearest points
_, indices = KDTree(points).query(global_points, k=1)
# Use a cKDTree to find the nearest points
_, indices = cKDTree(points).query(global_points, k=1)

return np.array([i for i in indices])


def outline(lats, lons, neighbours=5):
from scipy.spatial import KDTree
from scipy.spatial import cKDTree

xyx = latlon_to_xyz(lats, lons)
grid_points = np.array(xyx).transpose()

# Use a KDTree to find the nearest points
_, indices = KDTree(grid_points).query(grid_points, k=neighbours)
# Use a cKDTree to find the nearest points
_, indices = cKDTree(grid_points).query(grid_points, k=neighbours)

# Centre of the Earth
zero = np.array([0.0, 0.0, 0.0])
Expand Down Expand Up @@ -379,6 +379,19 @@ def serialise_mask(mask):
return result


def nearest_grid_points(source_latitudes, source_longitudes, target_latitudes, target_longitudes):
    """For each target point, return the index of the nearest source point.

    Points are compared in 3D cartesian space (via `latlon_to_xyz`), so
    distances behave correctly across the dateline and near the poles.
    """
    from scipy.spatial import cKDTree

    source_points = np.array(latlon_to_xyz(source_latitudes, source_longitudes)).transpose()
    target_points = np.array(latlon_to_xyz(target_latitudes, target_longitudes)).transpose()

    _, indices = cKDTree(source_points).query(target_points, k=1)
    return indices


if __name__ == "__main__":
global_lats, global_lons = np.meshgrid(
np.linspace(90, -90, 90),
Expand Down

0 comments on commit 8fd1000

Please sign in to comment.