|
12 | 12 | from functools import cached_property
|
13 | 13 |
|
14 | 14 | import numpy as np
|
| 15 | +from scipy.spatial import cKDTree |
15 | 16 |
|
16 | 17 | from .debug import Node
|
17 | 18 | from .debug import debug_indexing
|
@@ -142,95 +143,250 @@ def tree(self):
|
142 | 143 |
|
143 | 144 |
|
144 | 145 | class Cutout(GridsBase):
|
145 |
| - def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False): |
146 |
| - from anemoi.datasets.grids import cutout_mask |
147 |
| - |
| 146 | + def __init__(self, datasets, axis=3, cropping_distance=2.0, neighbours=5, min_distance_km=None, plot=None): |
| 147 | + """Initializes a Cutout object for hierarchical management of Limited Area |
| 148 | + Models (LAMs) and a global dataset, handling overlapping regions. |
| 149 | +
|
| 150 | + Args: |
| 151 | + datasets (list): List of LAM and global datasets. |
| 152 | + axis (int): Concatenation axis, must be set to 3. |
| 153 | + cropping_distance (float): Distance threshold in degrees for |
| 154 | + cropping cutouts. |
| 155 | + neighbours (int): Number of neighboring points to consider when |
| 156 | + constructing masks. |
| 157 | + min_distance_km (float, optional): Minimum distance threshold in km |
| 158 | + between grid points. |
| 159 | + plot (bool, optional): Flag to enable or disable visualization |
| 160 | + plots. |
| 161 | + """ |
148 | 162 | super().__init__(datasets, axis)
|
149 |
| - assert len(datasets) == 2, "CutoutGrids requires two datasets" |
| 163 | + assert len(datasets) >= 2, "CutoutGrids requires at least two datasets" |
150 | 164 | assert axis == 3, "CutoutGrids requires axis=3"
|
| 165 | + assert cropping_distance >= 0, "cropping_distance must be a non-negative number" |
| 166 | + if min_distance_km is not None: |
| 167 | + assert min_distance_km >= 0, "min_distance_km must be a non-negative number" |
| 168 | + |
| 169 | + self.lams = datasets[:-1] # Assume the last dataset is the global one |
| 170 | + self.globe = datasets[-1] |
| 171 | + self.axis = axis |
| 172 | + self.cropping_distance = cropping_distance |
| 173 | + self.neighbours = neighbours |
| 174 | + self.min_distance_km = min_distance_km |
| 175 | + self.plot = plot |
| 176 | + self.masks = [] # To store the masks for each LAM dataset |
| 177 | + self.global_mask = np.ones(self.globe.shape[-1], dtype=bool) |
| 178 | + |
| 179 | + # Initialize cumulative masks |
| 180 | + self._initialize_masks() |
| 181 | + |
| 182 | + def _initialize_masks(self): |
| 183 | + """Generates hierarchical masks for each LAM dataset by excluding |
| 184 | + overlapping regions with previous LAMs and creating a global mask for |
| 185 | + the global dataset. |
| 186 | +
|
| 187 | + Raises: |
| 188 | + ValueError: If the global mask dimension does not match the global |
| 189 | + dataset grid points. |
| 190 | + """ |
| 191 | + from anemoi.datasets.grids import cutout_mask |
151 | 192 |
|
152 |
| - # We assume that the LAM is the first dataset, and the global is the second |
153 |
| - # Note: the second fields does not really need to be global |
154 |
| - |
155 |
| - self.lam, self.globe = datasets |
156 |
| - self.mask = cutout_mask( |
157 |
| - self.lam.latitudes, |
158 |
| - self.lam.longitudes, |
159 |
| - self.globe.latitudes, |
160 |
| - self.globe.longitudes, |
161 |
| - plot=plot, |
162 |
| - min_distance_km=min_distance_km, |
163 |
| - cropping_distance=cropping_distance, |
164 |
| - neighbours=neighbours, |
165 |
| - ) |
166 |
| - assert len(self.mask) == self.globe.shape[3], ( |
167 |
| - len(self.mask), |
168 |
| - self.globe.shape[3], |
169 |
| - ) |
| 193 | + for i, lam in enumerate(self.lams): |
| 194 | + assert len(lam.shape) == len( |
| 195 | + self.globe.shape |
| 196 | + ), "LAMs and global dataset must have the same number of dimensions" |
| 197 | + lam_lats = lam.latitudes |
| 198 | + lam_lons = lam.longitudes |
| 199 | + # Create a mask for the global dataset excluding all LAM points |
| 200 | + global_overlap_mask = cutout_mask( |
| 201 | + lam.latitudes, |
| 202 | + lam.longitudes, |
| 203 | + self.globe.latitudes, |
| 204 | + self.globe.longitudes, |
| 205 | + plot=False, |
| 206 | + min_distance_km=self.min_distance_km, |
| 207 | + cropping_distance=self.cropping_distance, |
| 208 | + neighbours=self.neighbours, |
| 209 | + ) |
| 210 | + |
| 211 | + # Ensure the mask dimensions match the global grid points |
| 212 | + if global_overlap_mask.shape[0] != self.globe.shape[-1]: |
| 213 | + raise ValueError("Global mask dimension does not match global dataset grid " "points.") |
| 214 | + self.global_mask[~global_overlap_mask] = False |
| 215 | + |
| 216 | + # Create a mask for the LAM datasets hierarchically, excluding |
| 217 | + # points from previous LAMs |
| 218 | + lam_current_mask = np.ones(lam.shape[-1], dtype=bool) |
| 219 | + if i > 0: |
| 220 | + for j in range(i): |
| 221 | + prev_lam = self.lams[j] |
| 222 | + prev_lam_lats = prev_lam.latitudes |
| 223 | + prev_lam_lons = prev_lam.longitudes |
| 224 | + # Check for overlap by computing distances |
| 225 | + if self.has_overlap(prev_lam_lats, prev_lam_lons, lam_lats, lam_lons): |
| 226 | + lam_overlap_mask = cutout_mask( |
| 227 | + prev_lam_lats, |
| 228 | + prev_lam_lons, |
| 229 | + lam_lats, |
| 230 | + lam_lons, |
| 231 | + plot=False, |
| 232 | + min_distance_km=self.min_distance_km, |
| 233 | + cropping_distance=self.cropping_distance, |
| 234 | + neighbours=self.neighbours, |
| 235 | + ) |
| 236 | + lam_current_mask[~lam_overlap_mask] = False |
| 237 | + self.masks.append(lam_current_mask) |
| 238 | + |
| 239 | + def has_overlap(self, lats1, lons1, lats2, lons2, distance_threshold=1.0): |
| 240 | + """Checks for overlapping points between two sets of latitudes and |
| 241 | + longitudes within a specified distance threshold. |
| 242 | +
|
| 243 | + Args: |
| 244 | + lats1, lons1 (np.ndarray): Latitude and longitude arrays for the |
| 245 | + first dataset. |
| 246 | + lats2, lons2 (np.ndarray): Latitude and longitude arrays for the |
| 247 | + second dataset. |
| 248 | + distance_threshold (float): Distance in degrees to consider as |
| 249 | + overlapping. |
| 250 | +
|
| 251 | + Returns: |
| 252 | + bool: True if any points overlap within the distance threshold, |
| 253 | + otherwise False. |
| 254 | + """ |
| 255 | + # Create KDTree for the first set of points |
| 256 | + tree = cKDTree(np.vstack((lats1, lons1)).T) |
| 257 | + |
| 258 | + # Query the second set of points against the first tree |
| 259 | + distances, _ = tree.query(np.vstack((lats2, lons2)).T, k=1) |
| 260 | + |
| 261 | + # Check if any distance is less than the specified threshold |
| 262 | + return np.any(distances < distance_threshold) |
| 263 | + |
| 264 | + def __getitem__(self, index): |
| 265 | + """Retrieves data from the masked LAMs and global dataset based on the |
| 266 | + given index. |
| 267 | +
|
| 268 | + Args: |
| 269 | + index (int or slice or tuple): Index specifying the data to |
| 270 | + retrieve. |
| 271 | +
|
| 272 | + Returns: |
| 273 | + np.ndarray: Data array from the masked datasets based on the index. |
| 274 | + """ |
| 275 | + if isinstance(index, (int, slice)): |
| 276 | + index = (index, slice(None), slice(None), slice(None)) |
| 277 | + return self._get_tuple(index) |
| 278 | + |
| 279 | + def _get_tuple(self, index): |
| 280 | + """Helper method that applies masks and retrieves data from each dataset |
| 281 | + according to the specified index. |
| 282 | +
|
| 283 | + Args: |
| 284 | + index (tuple): Index specifying slices to retrieve data. |
| 285 | +
|
| 286 | + Returns: |
| 287 | + np.ndarray: Concatenated data array from all datasets based on the |
| 288 | + index. |
| 289 | + """ |
| 290 | + index, changes = index_to_slices(index, self.shape) |
| 291 | + # Select data from each LAM |
| 292 | + lam_data = [lam[index] for lam in self.lams] |
| 293 | + |
| 294 | + # First apply spatial indexing on `self.globe` and then apply the mask |
| 295 | + globe_data_sliced = self.globe[index[:3]] |
| 296 | + globe_data = globe_data_sliced[..., self.global_mask] |
| 297 | + |
| 298 | + # Concatenate LAM data with global data |
| 299 | + result = np.concatenate(lam_data + [globe_data], axis=self.axis) |
| 300 | + return apply_index_to_slices_changes(result, changes) |
170 | 301 |
|
171 | 302 | def collect_supporting_arrays(self, collected, *path):
|
172 |
| - collected.append((path, "cutout_mask", self.mask)) |
| 303 | + """Collects supporting arrays, including masks for each LAM and the global |
| 304 | + dataset. |
| 305 | +
|
| 306 | + Args: |
| 307 | + collected (list): List to which the supporting arrays are appended. |
| 308 | + *path: Variable length argument list specifying the paths for the masks. |
| 309 | + """ |
| 310 | + # Append masks for each LAM |
| 311 | + for i, (lam, mask) in enumerate(zip(self.lams, self.masks)): |
| 312 | + collected.append((path + (f"lam_{i}",), "cutout_mask", mask)) |
| 313 | + |
| 314 | + # Append the global mask |
| 315 | + collected.append((path + ("global",), "cutout_mask", self.global_mask)) |
173 | 316 |
|
174 | 317 | @cached_property
|
175 | 318 | def shape(self):
|
176 |
| - shape = self.lam.shape |
177 |
| - # Number of non-zero masked values in the globe dataset |
178 |
| - nb_globe = np.count_nonzero(self.mask) |
179 |
| - return shape[:-1] + (shape[-1] + nb_globe,) |
| 319 | + """Returns the shape of the Cutout, accounting for retained grid points |
| 320 | + across all LAMs and the global dataset. |
| 321 | +
|
| 322 | + Returns: |
| 323 | + tuple: Shape of the concatenated masked datasets. |
| 324 | + """ |
| 325 | + shapes = [np.sum(mask) for mask in self.masks] |
| 326 | + global_shape = np.sum(self.global_mask) |
| 327 | + return tuple(self.lams[0].shape[:-1] + (sum(shapes) + global_shape,)) |
180 | 328 |
|
181 | 329 | def check_same_resolution(self, d1, d2):
|
182 | 330 | # Turned off because we are combining different resolutions
|
183 | 331 | pass
|
184 | 332 |
|
185 | 333 | @property
|
186 |
| - def latitudes(self): |
187 |
| - return np.concatenate([self.lam.latitudes, self.globe.latitudes[self.mask]]) |
| 334 | + def grids(self): |
| 335 | + """Returns the number of grid points for each LAM and the global dataset |
| 336 | + after applying masks. |
188 | 337 |
|
189 |
| - @property |
190 |
| - def longitudes(self): |
191 |
| - return np.concatenate([self.lam.longitudes, self.globe.longitudes[self.mask]]) |
| 338 | + Returns: |
| 339 | + tuple: Count of retained grid points for each dataset. |
| 340 | + """ |
| 341 | + grids = [np.sum(mask) for mask in self.masks] |
| 342 | + grids.append(np.sum(self.global_mask)) |
| 343 | + return tuple(grids) |
192 | 344 |
|
193 |
| - def __getitem__(self, index): |
194 |
| - if isinstance(index, (int, slice)): |
195 |
| - index = (index, slice(None), slice(None), slice(None)) |
196 |
| - return self._get_tuple(index) |
| 345 | + @property |
| 346 | + def latitudes(self): |
| 347 | + """Returns the concatenated latitudes of each LAM and the global dataset |
| 348 | + after applying masks. |
197 | 349 |
|
198 |
| - @debug_indexing |
199 |
| - @expand_list_indexing |
200 |
| - def _get_tuple(self, index): |
201 |
| - assert self.axis >= len(index) or index[self.axis] == slice( |
202 |
| - None |
203 |
| - ), f"No support for selecting a subset of the 1D values {index} ({self.tree()})" |
204 |
| - index, changes = index_to_slices(index, self.shape) |
| 350 | + Returns: |
| 351 | + np.ndarray: Concatenated latitude array for the masked datasets. |
| 352 | + """ |
| 353 | + lam_latitudes = np.concatenate([lam.latitudes[mask] for lam, mask in zip(self.lams, self.masks)]) |
205 | 354 |
|
206 |
| - # In case index_to_slices has changed the last slice |
207 |
| - index, _ = update_tuple(index, self.axis, slice(None)) |
| 355 | + assert ( |
| 356 | + len(lam_latitudes) + len(self.globe.latitudes[self.global_mask]) == self.shape[-1] |
| 357 | + ), "Mismatch in number of latitudes" |
208 | 358 |
|
209 |
| - lam_data = self.lam[index] |
210 |
| - globe_data = self.globe[index] |
| 359 | + latitudes = np.concatenate([lam_latitudes, self.globe.latitudes[self.global_mask]]) |
| 360 | + return latitudes |
211 | 361 |
|
212 |
| - globe_data = globe_data[:, :, :, self.mask] |
| 362 | + @property |
| 363 | + def longitudes(self): |
| 364 | + """Returns the concatenated longitudes of each LAM and the global dataset |
| 365 | + after applying masks. |
213 | 366 |
|
214 |
| - result = np.concatenate([lam_data, globe_data], axis=self.axis) |
| 367 | + Returns: |
| 368 | + np.ndarray: Concatenated longitude array for the masked datasets. |
| 369 | + """ |
| 370 | + lam_longitudes = np.concatenate([lam.longitudes[mask] for lam, mask in zip(self.lams, self.masks)]) |
215 | 371 |
|
216 |
| - return apply_index_to_slices_changes(result, changes) |
| 372 | + assert ( |
| 373 | + len(lam_longitudes) + len(self.globe.longitudes[self.global_mask]) == self.shape[-1] |
| 374 | + ), "Mismatch in number of longitudes" |
217 | 375 |
|
218 |
| - @property |
219 |
| - def grids(self): |
220 |
| - for d in self.datasets: |
221 |
| - if len(d.grids) > 1: |
222 |
| - raise NotImplementedError("CutoutGrids does not support multi-grids datasets as inputs") |
223 |
| - shape = self.lam.shape |
224 |
| - return (shape[-1], self.shape[-1] - shape[-1]) |
| 376 | + longitudes = np.concatenate([lam_longitudes, self.globe.longitudes[self.global_mask]]) |
| 377 | + return longitudes |
225 | 378 |
|
226 | 379 | def tree(self):
|
| 380 | + """Generates a hierarchical tree structure for the `Cutout` instance and |
| 381 | + its associated datasets. |
| 382 | +
|
| 383 | + Returns: |
| 384 | + Node: A `Node` object representing the `Cutout` instance as the root |
| 385 | + node, with each dataset in `self.datasets` represented as a child |
| 386 | + node. |
| 387 | + """ |
227 | 388 | return Node(self, [d.tree() for d in self.datasets])
|
228 | 389 |
|
229 |
| - # def metadata_specific(self): |
230 |
| - # return super().metadata_specific( |
231 |
| - # mask=serialise_mask(self.mask), |
232 |
| - # ) |
233 |
| - |
234 | 390 |
|
235 | 391 | def grids_factory(args, kwargs):
|
236 | 392 | if "ensemble" in kwargs:
|
|
0 commit comments