Skip to content

Commit a8666fa

Browse files
fix: complement with max distance (#469)
## Description

- `max_distance` could not be passed as an argument to `complement`.
- `complement` with `max_distance` did not work: when no source point lies within `max_distance`, the distance is inf and the nearest-point index returned by the cKDTree is the size of the source data (an out-of-range index) rather than a valid grid point.

## What problem does this change solve?

Make `complement` usable with `max_distance`.

## What issue or task does this change relate to?

## Additional notes

## ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/***

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 42b9e27 commit a8666fa

File tree

1 file changed

+26
-17
lines changed

1 file changed

+26
-17
lines changed

src/anemoi/datasets/data/complement.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -293,21 +293,29 @@ def _get_tuple(self, index: TupleIndex) -> NDArray[Any]:
293293
index, previous = update_tuple(index, variable_index, slice(None))
294294
source_index = [self._source.name_to_index[x] for x in self.variables[previous]]
295295
source_data = self._source[index[0], source_index, index[2], ...]
296-
target_data = source_data[..., self._nearest_grid_points]
297-
298-
epsilon = 1e-8 # prevent division by zero
299-
weights = 1.0 / (self._distances + epsilon)
300-
weights = weights.astype(target_data.dtype)
301-
weights /= weights.sum(axis=1, keepdims=True) # normalize
302-
303-
# Reshape weights to broadcast correctly
304-
# Add leading singleton dimensions so it matches target_data shape
305-
while weights.ndim < target_data.ndim:
306-
weights = np.expand_dims(weights, axis=0)
307-
308-
# Compute weighted average along the last dimension
309-
final_point = np.sum(target_data * weights, axis=-1)
310-
result = final_point[..., index[3]]
296+
if any(self._nearest_grid_points >= source_data.shape[-1]):
297+
target_shape = source_data.shape[:-1] + self._target.shape[-1:]
298+
target_data = np.full(target_shape, np.nan, dtype=self._target.dtype)
299+
cond = self._nearest_grid_points < source_data.shape[-1]
300+
reachable = np.where(cond)[0]
301+
nearest_reachable = self._nearest_grid_points[cond]
302+
target_data[..., reachable] = source_data[..., nearest_reachable]
303+
result = target_data[..., index[3]]
304+
else:
305+
target_data = source_data[..., self._nearest_grid_points]
306+
epsilon = 1e-8 # prevent division by zero
307+
weights = 1.0 / (self._distances + epsilon)
308+
weights = weights.astype(target_data.dtype)
309+
weights /= weights.sum(axis=1, keepdims=True) # normalize
310+
311+
# Reshape weights to broadcast correctly
312+
# Add leading singleton dimensions so it matches target_data shape
313+
while weights.ndim < target_data.ndim:
314+
weights = np.expand_dims(weights, axis=0)
315+
316+
# Compute weighted average along the last dimension
317+
final_point = np.sum(target_data * weights, axis=-1)
318+
result = final_point[..., index[3]]
311319

312320
return apply_index_to_slices_changes(result, changes)
313321

@@ -353,8 +361,9 @@ def complement_factory(args: tuple, kwargs: dict) -> Dataset:
353361
}[interpolation]
354362

355363
if interpolation == "nearest":
356-
k = kwargs.pop("k", "1")
357-
complement = Class(target=target, source=source, k=k)._subset(**kwargs)
364+
k = kwargs.pop("k", 1)
365+
max_distance = kwargs.pop("max_distance", None)
366+
complement = Class(target=target, source=source, k=k, max_distance=max_distance)._subset(**kwargs)
358367

359368
else:
360369
complement = Class(target=target, source=source)._subset(**kwargs)

0 commit comments

Comments
 (0)