From 941bbf6822d174e0991a96f12589f15eb75db53c Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 30 Aug 2024 17:54:28 +0100 Subject: [PATCH] faster min-distance calculation --- src/anemoi/datasets/grids.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index 3aac8db4..be7eb6b8 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -189,17 +189,15 @@ def cutout_mask( xyx = latlon_to_xyz(lats, lons) lam_points = np.array(xyx).transpose() - # Use a KDTree to find the nearest points - kdtree = KDTree(lam_points) - - distances, indices = kdtree.query(global_points, k=neighbours) - if min_distance_km is not None: min_distance = min_distance_km / 6371.0 else: distances, _ = KDTree(global_points).query(global_points, k=2) min_distance = np.min(distances[:, 1]) + # Use a KDTree to find the nearest points + distances, indices = KDTree(lam_points).query(global_points, k=neighbours) + LOG.debug(f"cutout_mask using min_distance = {min_distance * 6371.0} km") # Centre of the Earth @@ -291,8 +289,7 @@ def thinning_mask( points = np.array(xyx).transpose() # Use a KDTree to find the nearest points - kdtree = KDTree(points) - _, indices = kdtree.query(global_points, k=1) + _, indices = KDTree(points).query(global_points, k=1) return np.array([i for i in indices])