Skip to content

Commit 8fd1000

Browse files
b8raoultOpheliaMirallespre-commit-ci[bot]floriankrb
authored
Complement a dataset from an other, using interpolation to nearest neighbourg (#152)
* add augment feature * fix shape * Make missing date object lam compatible * Fix bug related to changes in earthkit-data * Add ordering of variables * Revert to uncommented + format * Remove unlinked code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove print messages * update documentation * Fix variables match with source data * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: omiralles <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Florian Pinault <[email protected]>
1 parent 22ae74c commit 8fd1000

File tree

10 files changed

+253
-14
lines changed

10 files changed

+253
-14
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ Keep it human-readable, your future self will thank you!
1515
- Fix metadata serialization handling of numpy.integer (#140)
1616
- Fix negative variance for constant variables (#148)
1717
- Fix cutout slicing of grid dimension (#145)
18-
- update acumulation (#158)
18+
- Use cKDTree instead of KDTree
19+
- Implement 'complement' feature
1920
- Add ability to patch xarrays (#160)
2021

2122
### Added
2223

2324
- Call filters from anemoi-transform
24-
- make test optional when adls is not installed Pull request #110
25+
- Make test optional when adls is not installed Pull request #110
2526
- Add wz_to_w, orog_to_z, and sum filters (#149)
2627

2728
## [0.5.8](https://github.com/ecmwf/anemoi-datasets/compare/0.5.7...0.5.8) - 2024-10-26

docs/using/code/complement1_.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
open_dataset(
2+
complement=dataset1,
3+
source=dataset2,
4+
what="variables",
5+
interpolate="nearest",
6+
)

docs/using/code/complement2_.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
open_dataset(
2+
cutout=[
3+
{
4+
"complement": lam_dataset,
5+
"source": global_dataset,
6+
"interpolate": "nearest",
7+
},
8+
{
9+
"dataset": global_dataset,
10+
},
11+
]
12+
)

docs/using/code/complement3_.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
open_dataset(
2+
complement=dataset1,
3+
source=dataset2,
4+
)

docs/using/combining.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,32 @@ The difference can be seen at the boundary between the two grids:
182182
To debug the combination, you can pass `plot=True` to the `cutout`
183183
function (when running from a Notebook), or use `plot="prefix"` to save
184184
the plots to series of PNG files in the current directory.
185+
186+
.. _complement:
187+
188+
************
189+
complement
190+
************
191+
192+
This feature will interpolate the variables of `dataset2` that are not
193+
in `dataset1` to the grid of `dataset1`, add them to the list of
194+
variables of `dataset1` and return the result.
195+
196+
.. literalinclude:: code/complement1_.py
197+
198+
Currently ``what`` can only be ``variables`` and can be omitted.
199+
200+
The value for ``interpolate`` can be one of ``none`` (default) or
201+
``nearest``. In the case of ``none``, the grids of the two datasets must
202+
match.
203+
204+
This feature was originally designed to be used in conjunction with
205+
``cutout``, where `dataset1` is the lam, and `dataset2` is the global
206+
dataset.
207+
208+
.. literalinclude:: code/complement2_.py
209+
210+
Another use case is to simply bring all non-overlapping variables of a
211+
dataset into another:
212+
213+
.. literalinclude:: code/complement3_.py

src/anemoi/datasets/create/functions/sources/xarray/field.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,6 @@ def forecast_reference_time(self):
120120
def __repr__(self):
121121
return repr(self._metadata)
122122

123-
def _values(self):
123+
def _values(self, dtype=None):
124124
# we don't use .values as this will download the data
125125
return self.selection
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
11+
import logging
12+
from functools import cached_property
13+
14+
from ..grids import nearest_grid_points
15+
from .debug import Node
16+
from .forwards import Combined
17+
from .indexing import apply_index_to_slices_changes
18+
from .indexing import index_to_slices
19+
from .indexing import update_tuple
20+
from .misc import _auto_adjust
21+
from .misc import _open
22+
23+
LOG = logging.getLogger(__name__)
24+
25+
26+
class Complement(Combined):
    """Expose the variables of ``source`` that are absent from ``target``.

    The dataset only lists the variables found in ``source`` but not in
    ``target``. Its shape is the shape of ``target`` with the variable axis
    (axis 1) resized to the number of those variables.

    This class does not define ``_get_tuple``; subclasses provide it to
    actually fetch the data.
    """

    def __init__(self, target, source, what="variables", interpolation="nearest"):
        """Record both datasets and work out which variables to bring over.

        Args:
            target: Dataset to be complemented.
            source: Dataset providing the additional variables.
            what: Accepted but not used by this class.
            interpolation: Accepted but not used by this class.

        Raises:
            ValueError: If every variable of ``source`` is already in ``target``.
        """
        super().__init__([target, source])

        # We add the variables of dataset[1] to dataset[0],
        # interpolated on the grid of dataset[0]

        self.target = target
        self.source = source

        # Keep the same order as the original dataset
        target_variables = self.target.variables
        self._variables = [v for v in self.source.variables if v not in target_variables]

        if not self._variables:
            raise ValueError("Complement: no missing variables")

    @property
    def variables(self):
        """Names of the variables taken from ``source``, in ``source`` order."""
        return self._variables

    @property
    def name_to_index(self):
        """Mapping from variable name to its position in ``variables``."""
        return {v: i for i, v in enumerate(self.variables)}

    @property
    def shape(self):
        """Shape of ``target`` with axis 1 replaced by the number of complemented variables."""
        shape = self.target.shape
        return (shape[0], len(self._variables)) + shape[2:]

    @property
    def variables_metadata(self):
        """Entries of ``source.variables_metadata`` restricted to the complemented variables."""
        return {k: v for k, v in self.source.variables_metadata.items() if k in self._variables}

    def check_same_variables(self, d1, d2):
        """Do nothing: the two datasets are not expected to share variables.

        NOTE(review): presumably overrides a check performed by ``Combined`` -- confirm.
        """
        pass

    @cached_property
    def missing(self):
        """Union of the ``missing`` sets of ``source`` and ``target``."""
        missing = self.source.missing.copy()
        missing = missing | self.target.missing
        return set(missing)

    def tree(self):
        """Generates a hierarchical tree structure for the `Complement` instance
        and its associated datasets.

        Returns:
            Node: A `Node` object representing the `Complement` instance as the
            root node, with the target and source datasets (in that order)
            represented as child nodes.
        """
        return Node(self, [d.tree() for d in (self.target, self.source)])

    def __getitem__(self, index):
        """Return the data selected by ``index``.

        An ``int`` or ``slice`` is expanded to a 4-tuple that selects everything
        on the three remaining axes; any other index is passed on unchanged to
        ``_get_tuple``.
        """
        if isinstance(index, (int, slice)):
            index = (index, slice(None), slice(None), slice(None))
        return self._get_tuple(index)
88+
89+
90+
class ComplementNone(Complement):
    """Complement variant that reads the data from ``source`` without interpolation."""

    def __init__(self, target, source):
        super().__init__(target, source)

    def _get_tuple(self, index):
        """Normalise ``index`` against this dataset's shape and read ``source`` with it.

        NOTE(review): the normalised index, including its variable axis, is
        handed to ``source`` as-is -- confirm it addresses the intended variables.
        """
        slices, changes = index_to_slices(index, self.shape)
        return apply_index_to_slices_changes(self.source[slices], changes)
99+
100+
101+
class ComplementNearest(Complement):
    """Complement variant that maps ``source`` data onto the ``target`` grid
    by picking, for each target point, its nearest source point.
    """

    def __init__(self, target, source):
        super().__init__(target, source)

        # For every target grid point, the index of the closest source grid
        # point; computed once and reused by every read.
        src, tgt = self.source, self.target
        self._nearest_grid_points = nearest_grid_points(src.latitudes, src.longitudes, tgt.latitudes, tgt.longitudes)

    def check_compatibility(self, d1, d2):
        """Do nothing: no compatibility check is applied between the two datasets."""
        pass

    def _get_tuple(self, index):
        """Read the requested variables from ``source`` and remap them to the target grid."""
        variable_axis = 1

        slices, changes = index_to_slices(index, self.shape)
        # Take the variable selection out of the index (it refers to this
        # dataset's variables) and translate it into positions in `source`.
        slices, requested = update_tuple(slices, variable_axis, slice(None))
        columns = [self.source.name_to_index[name] for name in self.variables[requested]]

        from_source = self.source[slices[0], columns, slices[2], ...]
        # Re-index the last axis with the nearest-neighbour lookup table,
        # then apply the caller's selection on that axis.
        on_target_grid = from_source[..., self._nearest_grid_points]
        selected = on_target_grid[..., slices[3]]

        return apply_index_to_slices_changes(selected, changes)
127+
128+
129+
def complement_factory(args, kwargs):
    """Build the dataset requested by ``open_dataset(complement=..., source=...)``.

    The variables of ``source`` that are missing from ``complement`` are
    wrapped in a ``Complement`` dataset, joined with the target, and the
    result is reordered to follow the variable order of ``source``.

    Args:
        args: Positional arguments; must be empty.
        kwargs: Keyword arguments. The following keys are consumed (popped):
            ``source``, ``complement``, ``what`` (default ``"variables"``),
            ``interpolation`` (default ``"none"``) and its alias
            ``interpolate``. Remaining keys are forwarded to ``_auto_adjust``
            and ``_subset``.

    Returns:
        The complemented dataset, with variables ordered as in ``source``.

    Raises:
        NotImplementedError: If ``what`` is not ``"variables"`` or the
            interpolation method is neither ``"none"`` nor ``"nearest"``.
        ValueError: If ``interpolate`` and ``interpolation`` are both given
            with different values.
    """
    assert len(args) == 0, args

    source = kwargs.pop("source")
    target = kwargs.pop("complement")
    what = kwargs.pop("what", "variables")

    # The documentation spells the option `interpolate` while the original
    # implementation only read `interpolation`; accept both spellings so
    # the documented keyword is actually honoured.
    if "interpolate" in kwargs:
        alias = kwargs.pop("interpolate")
        if kwargs.setdefault("interpolation", alias) != alias:
            raise ValueError(
                f"Complement: conflicting values for 'interpolate' ({alias}) "
                f"and 'interpolation' ({kwargs['interpolation']})"
            )
    interpolation = kwargs.pop("interpolation", "none")

    if what != "variables":
        raise NotImplementedError(f"Complement what={what} not implemented")

    if interpolation not in ("none", "nearest"):
        raise NotImplementedError(f"Complement method={interpolation} not implemented")

    # Imported locally, and only once the arguments have been validated.
    from .select import Select

    source = _open(source)
    target = _open(target)
    # `select` is the same as `variables`
    (source, target), kwargs = _auto_adjust([source, target], kwargs, exclude=["select"])

    Class = {
        None: ComplementNone,
        "none": ComplementNone,
        "nearest": ComplementNearest,
    }[interpolation]

    complement = Class(target=target, source=source)._subset(**kwargs)

    # Will join the datasets along the variables axis
    reorder = source.variables
    complemented = _open([target, complement])
    # NOTE(review): the key "reoder" is kept exactly as in the original; it is
    # runtime data and other code may rely on this spelling.
    ordered = (
        Select(complemented, complemented._reorder_to_columns(reorder), {"reoder": reorder})._subset(**kwargs).mutate()
    )
    return ordered

src/anemoi/datasets/data/join.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def variables(self):
118118
def variables_metadata(self):
119119
result = {}
120120
variables = [v for v in self.variables if not (v.startswith("(") and v.endswith(")"))]
121+
121122
for d in self.datasets:
122123
md = d.variables_metadata
123124
for v in variables:

src/anemoi/datasets/data/misc.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def _open(a):
194194
raise NotImplementedError(f"Unsupported argument: {type(a)}")
195195

196196

197-
def _auto_adjust(datasets, kwargs):
197+
def _auto_adjust(datasets, kwargs, exclude=None):
198198

199199
if "adjust" not in kwargs:
200200
return datasets, kwargs
@@ -214,6 +214,9 @@ def _auto_adjust(datasets, kwargs):
214214
for a in adjust_list:
215215
adjust_set.update(ALIASES.get(a, [a]))
216216

217+
if exclude is not None:
218+
adjust_set -= set(exclude)
219+
217220
extra = set(adjust_set) - set(ALIASES["all"])
218221
if extra:
219222
raise ValueError(f"Invalid adjust keys: {extra}")
@@ -335,6 +338,12 @@ def _open_dataset(*args, **kwargs):
335338
assert not sets, sets
336339
return cutout_factory(args, kwargs).mutate()
337340

341+
if "complement" in kwargs:
342+
from .complement import complement_factory
343+
344+
assert not sets, sets
345+
return complement_factory(args, kwargs).mutate()
346+
338347
for name in ("datasets", "dataset"):
339348
if name in kwargs:
340349
datasets = kwargs.pop(name)

src/anemoi/datasets/grids.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def cutout_mask(
152152
plot=None,
153153
):
154154
"""Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]"""
155-
from scipy.spatial import KDTree
155+
from scipy.spatial import cKDTree
156156

157157
# TODO: transform min_distance from lat/lon to xyz
158158

@@ -195,13 +195,13 @@ def cutout_mask(
195195
min_distance = min_distance_km / 6371.0
196196
else:
197197
points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km]
198-
distances, _ = KDTree(points).query(points, k=2)
198+
distances, _ = cKDTree(points).query(points, k=2)
199199
min_distance = np.min(distances[:, 1])
200200

201201
LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km")
202202

203-
# Use a KDTree to find the nearest points
204-
distances, indices = KDTree(lam_points).query(global_points, k=neighbours)
203+
# Use a cKDTree to find the nearest points
204+
distances, indices = cKDTree(lam_points).query(global_points, k=neighbours)
205205

206206
# Centre of the Earth
207207
zero = np.array([0.0, 0.0, 0.0])
@@ -255,7 +255,7 @@ def thinning_mask(
255255
cropping_distance=2.0,
256256
):
257257
"""Return the list of points in [lats, lons] closest to [global_lats, global_lons]"""
258-
from scipy.spatial import KDTree
258+
from scipy.spatial import cKDTree
259259

260260
assert global_lats.ndim == 1
261261
assert global_lons.ndim == 1
@@ -291,20 +291,20 @@ def thinning_mask(
291291
xyx = latlon_to_xyz(lats, lons)
292292
points = np.array(xyx).transpose()
293293

294-
# Use a KDTree to find the nearest points
295-
_, indices = KDTree(points).query(global_points, k=1)
294+
# Use a cKDTree to find the nearest points
295+
_, indices = cKDTree(points).query(global_points, k=1)
296296

297297
return np.array([i for i in indices])
298298

299299

300300
def outline(lats, lons, neighbours=5):
301-
from scipy.spatial import KDTree
301+
from scipy.spatial import cKDTree
302302

303303
xyx = latlon_to_xyz(lats, lons)
304304
grid_points = np.array(xyx).transpose()
305305

306-
# Use a KDTree to find the nearest points
307-
_, indices = KDTree(grid_points).query(grid_points, k=neighbours)
306+
# Use a cKDTree to find the nearest points
307+
_, indices = cKDTree(grid_points).query(grid_points, k=neighbours)
308308

309309
# Centre of the Earth
310310
zero = np.array([0.0, 0.0, 0.0])
@@ -379,6 +379,19 @@ def serialise_mask(mask):
379379
return result
380380

381381

382+
def nearest_grid_points(source_latitudes, source_longitudes, target_latitudes, target_longitudes):
    """Return, for each target point, the index of its nearest source point.

    Both grids are converted with ``latlon_to_xyz`` and the search is done
    with a k-d tree built on the converted source points.

    Args:
        source_latitudes: Latitudes of the source grid points.
        source_longitudes: Longitudes of the source grid points.
        target_latitudes: Latitudes of the target grid points.
        target_longitudes: Longitudes of the target grid points.

    Returns:
        The indices returned by the tree query with ``k=1``: one index into
        the source points per target point.
    """
    from scipy.spatial import cKDTree

    # One row per grid point, one column per coordinate.
    source_points = np.array(latlon_to_xyz(source_latitudes, source_longitudes)).transpose()
    target_points = np.array(latlon_to_xyz(target_latitudes, target_longitudes)).transpose()

    tree = cKDTree(source_points)
    _distances, nearest = tree.query(target_points, k=1)
    return nearest
393+
394+
382395
if __name__ == "__main__":
383396
global_lats, global_lons = np.meshgrid(
384397
np.linspace(90, -90, 90),

0 commit comments

Comments
 (0)