Skip to content

Commit a795999

Browse files
committed
better control on cutout
1 parent 6d45f38 commit a795999

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

src/anemoi/datasets/data/grids.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def tree(self):
128128

129129

130130
class Cutout(GridsBase):
131-
def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, plot=False):
131+
def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False):
132132
from anemoi.datasets.grids import cutout_mask
133133

134134
super().__init__(datasets, axis)
@@ -147,6 +147,7 @@ def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0,
147147
plot=plot,
148148
min_distance_km=min_distance_km,
149149
cropping_distance=cropping_distance,
150+
neighbours=neighbours,
150151
)
151152
assert len(self.mask) == self.globe.shape[3], (
152153
len(self.mask),
@@ -234,6 +235,7 @@ def cutout_factory(args, kwargs):
234235
plot = kwargs.pop("plot", None)
235236
min_distance_km = kwargs.pop("min_distance_km", None)
236237
cropping_distance = kwargs.pop("cropping_distance", 2.0)
238+
neighbours = kwargs.pop("neighbours", 5)
237239

238240
assert len(args) == 0
239241
assert isinstance(cutout, (list, tuple))
@@ -244,6 +246,7 @@ def cutout_factory(args, kwargs):
244246
return Cutout(
245247
datasets,
246248
axis=axis,
249+
neighbours=neighbours,
247250
min_distance_km=min_distance_km,
248251
cropping_distance=cropping_distance,
249252
plot=plot,

src/anemoi/datasets/grids.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def cutout_mask(
145145
global_lats,
146146
global_lons,
147147
cropping_distance=2.0,
148+
neighbours=5,
148149
min_distance_km=None,
149150
plot=None,
150151
):
@@ -191,7 +192,8 @@ def cutout_mask(
191192

192193
# Use a KDTree to find the nearest points
193194
kdtree = KDTree(lam_points)
194-
distances, indices = kdtree.query(global_points, k=4)
195+
196+
distances, indices = kdtree.query(global_points, k=neighbours)
195197

196198
if min_distance_km is not None:
197199
min_distance = min_distance_km / 6371.0
@@ -213,25 +215,18 @@ def cutout_mask(
213215
# We check more than one triangle in case te global point
214216
# is near the edge of triangle, (the lam point and global points are colinear)
215217

216-
t1 = Triangle3D(lam_points[index[0]], lam_points[index[1]], lam_points[index[2]])
217-
t2 = Triangle3D(lam_points[index[1]], lam_points[index[2]], lam_points[index[3]])
218-
t3 = Triangle3D(lam_points[index[2]], lam_points[index[3]], lam_points[index[0]])
219-
t4 = Triangle3D(lam_points[index[3]], lam_points[index[0]], lam_points[index[1]])
220-
221-
# The point is inside the triangle if the intersection with the ray
222-
# from the point to the centre of the Earth is not None
223-
# (the direction of the ray is not important)
224-
225-
intersect = (
226-
t1.intersect(zero, global_point)
227-
or t2.intersect(zero, global_point)
228-
or t3.intersect(zero, global_point)
229-
or t4.intersect(zero, global_point)
230-
)
218+
inside = False
219+
for j in range(neighbours):
220+
t = Triangle3D(
221+
lam_points[index[j]], lam_points[index[(j + 1) % neighbours]], lam_points[index[(j + 2) % neighbours]]
222+
)
223+
inside = t.intersect(zero, global_point)
224+
if inside:
225+
break
231226

232227
close = np.min(distance) <= min_distance
233228

234-
inside_lam.append(intersect or close)
229+
inside_lam.append(inside or close)
235230

236231
j = 0
237232
inside_lam = np.array(inside_lam)

0 commit comments

Comments
 (0)