Skip to content

Commit e1ab0b8

Browse files
paulina-tPaulina Met.pre-commit-ci[bot]
authored
Feature/support multiple lams to the Cutout class (#113)
* Enhance Cutout class to support multiple LAMs with hierarchical masking. --------- Co-authored-by: Paulina Met. <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 32a2fa9 commit e1ab0b8

File tree

7 files changed

+1484
-62
lines changed

7 files changed

+1484
-62
lines changed

src/anemoi/datasets/data/grids.py

+218-62
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from functools import cached_property
1313

1414
import numpy as np
15+
from scipy.spatial import cKDTree
1516

1617
from .debug import Node
1718
from .debug import debug_indexing
@@ -142,95 +143,250 @@ def tree(self):
142143

143144

144145
class Cutout(GridsBase):
    """Combine one or more Limited Area Models (LAMs) with a global dataset.

    The datasets are concatenated along the grid axis. Points of the global
    dataset that are covered by any LAM are dropped, and points of a LAM that
    are covered by an earlier LAM in the list are dropped as well, so earlier
    LAMs take precedence over later ones and all LAMs over the global dataset.
    """

    def __init__(self, datasets, axis=3, cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None):
        """Initialise the cutout and compute the masks for every dataset.

        Args:
            datasets (list): LAM datasets followed by the global dataset
                (the last entry is always treated as the global one).
            axis (int): Concatenation axis, must be 3.
            cropping_distance (float): Non-negative value forwarded to
                ``cutout_mask`` as ``cropping_distance``.
            neighbours (int): Value forwarded to ``cutout_mask`` as
                ``neighbours``.
            min_distance_km (float, optional): Non-negative value forwarded
                to ``cutout_mask`` as ``min_distance_km``.
            plot (optional): Stored on the instance as ``self.plot``.
        """
        super().__init__(datasets, axis)
        assert len(datasets) >= 2, "CutoutGrids requires at least two datasets"
        assert axis == 3, "CutoutGrids requires axis=3"
        assert cropping_distance >= 0, "cropping_distance must be a non-negative number"
        if min_distance_km is not None:
            assert min_distance_km >= 0, "min_distance_km must be a non-negative number"

        self.lams = datasets[:-1]  # Every dataset but the last is a LAM
        self.globe = datasets[-1]  # The last dataset is assumed to be the global one
        self.axis = axis
        self.cropping_distance = cropping_distance
        self.neighbours = neighbours
        self.min_distance_km = min_distance_km
        # NOTE(review): `plot` is stored but every cutout_mask call below uses
        # plot=False, so it currently has no effect -- confirm this is intended.
        self.plot = plot
        self.masks = []  # One boolean mask per LAM, filled by _initialize_masks
        self.global_mask = np.ones(self.globe.shape[-1], dtype=bool)

        self._initialize_masks()

    def _initialize_masks(self):
        """Build the hierarchical masks for the LAMs and the global dataset.

        For each LAM, the global mask is updated to drop the global points
        reported by ``cutout_mask`` as covered by that LAM, and a LAM mask is
        built that drops the LAM points covered by any earlier LAM.

        Raises:
            ValueError: If a mask returned for the global dataset does not
                have one entry per global grid point.
        """
        from anemoi.datasets.grids import cutout_mask

        # Start from a clean state so that calling this method again cannot
        # append a second set of masks.
        self.masks = []
        self.global_mask = np.ones(self.globe.shape[-1], dtype=bool)

        mask_options = dict(
            plot=False,
            min_distance_km=self.min_distance_km,
            cropping_distance=self.cropping_distance,
            neighbours=self.neighbours,
        )
        globe_lats = self.globe.latitudes
        globe_lons = self.globe.longitudes

        for i, lam in enumerate(self.lams):
            assert len(lam.shape) == len(
                self.globe.shape
            ), "LAMs and global dataset must have the same number of dimensions"
            lam_lats = lam.latitudes
            lam_lons = lam.longitudes

            # Global points to keep with respect to this LAM
            global_overlap_mask = cutout_mask(lam_lats, lam_lons, globe_lats, globe_lons, **mask_options)
            if global_overlap_mask.shape[0] != self.globe.shape[-1]:
                raise ValueError("Global mask dimension does not match global dataset grid points.")
            self.global_mask[~global_overlap_mask] = False

            # Points of this LAM to keep: drop everything covered by an
            # earlier (higher-priority) LAM.
            lam_current_mask = np.ones(lam.shape[-1], dtype=bool)
            for prev_lam in self.lams[:i]:
                prev_lats = prev_lam.latitudes
                prev_lons = prev_lam.longitudes
                # Only compute the mask when the two LAMs are close enough
                if self.has_overlap(prev_lats, prev_lons, lam_lats, lam_lons):
                    lam_overlap_mask = cutout_mask(prev_lats, prev_lons, lam_lats, lam_lons, **mask_options)
                    lam_current_mask[~lam_overlap_mask] = False
            self.masks.append(lam_current_mask)

    @staticmethod
    def _latlon_to_unit_xyz(lats, lons):
        """Convert latitudes/longitudes in degrees to points on the unit sphere."""
        lat = np.radians(np.asarray(lats, dtype=float).ravel())
        lon = np.radians(np.asarray(lons, dtype=float).ravel())
        cos_lat = np.cos(lat)
        return np.column_stack((cos_lat * np.cos(lon), cos_lat * np.sin(lon), np.sin(lat)))

    def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0):
        """Tell whether two sets of points come closer than a threshold.

        Distances are measured as great-circle angles, so the result does not
        depend on the longitude convention (0..360 or -180..180), is correct
        across the longitude seam and is not distorted at high latitudes.

        Args:
            lats1, lons1 (np.ndarray): Latitudes and longitudes, in degrees,
                of the first set of points.
            lats2, lons2 (np.ndarray): Latitudes and longitudes, in degrees,
                of the second set of points.
            distance_threshold (float): Angular distance in degrees below
                which two points are considered overlapping.

        Returns:
            bool: True if at least one point of the second set is strictly
            closer than the threshold to a point of the first set.
        """
        xyz1 = self._latlon_to_unit_xyz(lats1, lons1)
        xyz2 = self._latlon_to_unit_xyz(lats2, lons2)
        if len(xyz1) == 0 or len(xyz2) == 0:
            return False

        # Compare straight-line (chord) distances between unit vectors: a
        # great-circle angle `a` corresponds to a chord of length 2*sin(a/2).
        angle = np.radians(min(max(float(distance_threshold), 0.0), 180.0))
        max_chord = 2.0 * np.sin(angle / 2.0)

        tree = cKDTree(xyz1)
        distances, _ = tree.query(xyz2, k=1)
        return bool(np.any(distances < max_chord))

    def __getitem__(self, index):
        """Retrieve data from the masked LAMs and global dataset.

        Args:
            index (int or slice or tuple): Index of the data to retrieve. An
                int or a slice selects along the first axis only.

        Returns:
            np.ndarray: The value returned by ``_get_tuple`` for the index.
        """
        if isinstance(index, (int, slice)):
            index = (index, slice(None), slice(None), slice(None))
        return self._get_tuple(index)

    # NOTE(review): the previous implementation of this method was wrapped in
    # @debug_indexing and @expand_list_indexing -- confirm that dropping them
    # was intentional.
    def _get_tuple(self, index):
        """Read, mask and concatenate the data of every dataset.

        The leading (non-grid) indices are applied to each dataset, the
        boolean masks are applied on the last axis, the pieces are
        concatenated along ``self.axis`` and only then is the grid-axis index
        applied, so that it refers to the combined grid described by
        ``shape``, ``latitudes`` and ``longitudes``.

        Args:
            index (tuple): Index specifying the data to retrieve.

        Returns:
            np.ndarray: Concatenated data of all LAMs followed by the global
            dataset.
        """
        index, changes = index_to_slices(index, self.shape)
        # presumably index now holds one slice per dimension -- the guard
        # below keeps the method working if the grid entry is absent.
        leading = index[: self.axis]
        grid_index = index[self.axis] if len(index) > self.axis else slice(None)

        # Apply each LAM's mask so the data matches the masked coordinates
        lam_data = [lam[leading][..., mask] for lam, mask in zip(self.lams, self.masks)]
        globe_data = self.globe[leading][..., self.global_mask]

        result = np.concatenate(lam_data + [globe_data], axis=self.axis)
        # The grid selection is relative to the combined grid
        result = result[..., grid_index]
        return apply_index_to_slices_changes(result, changes)

    def collect_supporting_arrays(self, collected, *path):
        """Append the mask of each LAM and of the global dataset.

        Args:
            collected (list): List to which ``(path, name, array)`` tuples
                are appended.
            *path: Path components; ``"lam_<i>"`` or ``"global"`` is added
                as the last component of each appended entry.
        """
        for i, mask in enumerate(self.masks):
            collected.append((path + (f"lam_{i}",), "cutout_mask", mask))

        collected.append((path + ("global",), "cutout_mask", self.global_mask))

    @cached_property
    def shape(self):
        """Return the shape of the cutout.

        The leading dimensions are those of the first LAM; the last one is
        the total number of grid points retained by all the masks.

        Returns:
            tuple: Shape of the concatenated masked datasets.
        """
        # Plain ints keep numpy scalar types out of the shape tuple
        retained = sum(int(np.count_nonzero(mask)) for mask in self.masks)
        retained += int(np.count_nonzero(self.global_mask))
        return tuple(self.lams[0].shape[:-1]) + (retained,)

    def check_same_resolution(self, d1, d2):
        # Turned off because we are combining different resolutions
        pass

    @property
    def grids(self):
        """Return the number of retained grid points per dataset.

        Returns:
            tuple: One count per LAM, followed by the count for the global
            dataset.
        """
        counts = [int(np.count_nonzero(mask)) for mask in self.masks]
        counts.append(int(np.count_nonzero(self.global_mask)))
        return tuple(counts)

    def _masked_coordinates(self, name):
        """Concatenate the masked coordinate `name` of the LAMs and the globe."""
        pieces = [getattr(lam, name)[mask] for lam, mask in zip(self.lams, self.masks)]
        pieces.append(getattr(self.globe, name)[self.global_mask])
        values = np.concatenate(pieces)
        assert len(values) == self.shape[-1], f"Mismatch in number of {name}"
        return values

    @property
    def latitudes(self):
        """Return the latitudes of the retained points.

        Returns:
            np.ndarray: Masked latitudes of each LAM followed by those of the
            global dataset.
        """
        return self._masked_coordinates("latitudes")

    @property
    def longitudes(self):
        """Return the longitudes of the retained points.

        Returns:
            np.ndarray: Masked longitudes of each LAM followed by those of
            the global dataset.
        """
        return self._masked_coordinates("longitudes")

    def tree(self):
        """Return the tree describing this cutout and its datasets.

        Returns:
            Node: A node for this instance whose children are the trees of
            the datasets in ``self.datasets``.
        """
        return Node(self, [d.tree() for d in self.datasets])
228389

229-
# def metadata_specific(self):
230-
# return super().metadata_specific(
231-
# mask=serialise_mask(self.mask),
232-
# )
233-
234390

235391
def grids_factory(args, kwargs):
236392
if "ensemble" in kwargs:

tools/grids/grids3.yaml

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
common:
2+
mars_request: &mars_request
3+
expver: "0001"
4+
grid: 0.25/0.25
5+
area: [40, 25, 20, 60]
6+
rotation: [-20, -40]
7+
8+
dates:
9+
start: 2024-01-01 00:00:00
10+
end: 2024-01-01 18:00:00
11+
frequency: 6h
12+
13+
input:
14+
join:
15+
- mars:
16+
<<: *mars_request
17+
param: [2t, 10u, 10v, lsm]
18+
levtype: sfc
19+
stream: oper
20+
type: an
21+
- mars:
22+
<<: *mars_request
23+
param: [q, t, z]
24+
levtype: pl
25+
level: [50, 100]
26+
stream: oper
27+
type: an
28+
- accumulations:
29+
<<: *mars_request
30+
levtype: sfc
31+
param: [cp, tp]
32+
- forcings:
33+
template: ${input.join.0.mars}
34+
param:
35+
- cos_latitude
36+
- sin_latitude
37+
38+
output:
39+
order_by: [valid_datetime, param_level, number]
40+
remapping:
41+
param_level: "{param}_{levelist}"
42+
statistics: param_level

0 commit comments

Comments
 (0)