Skip to content

Commit

Permalink
better control on cutout
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Aug 29, 2024
1 parent 6d45f38 commit a795999
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 18 deletions.
5 changes: 4 additions & 1 deletion src/anemoi/datasets/data/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def tree(self):


class Cutout(GridsBase):
def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, plot=False):
def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False):
from anemoi.datasets.grids import cutout_mask

super().__init__(datasets, axis)
Expand All @@ -147,6 +147,7 @@ def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0,
plot=plot,
min_distance_km=min_distance_km,
cropping_distance=cropping_distance,
neighbours=neighbours,
)
assert len(self.mask) == self.globe.shape[3], (
len(self.mask),
Expand Down Expand Up @@ -234,6 +235,7 @@ def cutout_factory(args, kwargs):
plot = kwargs.pop("plot", None)
min_distance_km = kwargs.pop("min_distance_km", None)
cropping_distance = kwargs.pop("cropping_distance", 2.0)
neighbours = kwargs.pop("neighbours", 5)

assert len(args) == 0
assert isinstance(cutout, (list, tuple))
Expand All @@ -244,6 +246,7 @@ def cutout_factory(args, kwargs):
return Cutout(
datasets,
axis=axis,
neighbours=neighbours,
min_distance_km=min_distance_km,
cropping_distance=cropping_distance,
plot=plot,
Expand Down
29 changes: 12 additions & 17 deletions src/anemoi/datasets/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def cutout_mask(
global_lats,
global_lons,
cropping_distance=2.0,
neighbours=5,
min_distance_km=None,
plot=None,
):
Expand Down Expand Up @@ -191,7 +192,8 @@ def cutout_mask(

# Use a KDTree to find the nearest points
kdtree = KDTree(lam_points)
distances, indices = kdtree.query(global_points, k=4)

distances, indices = kdtree.query(global_points, k=neighbours)

if min_distance_km is not None:
min_distance = min_distance_km / 6371.0
Expand All @@ -213,25 +215,18 @@ def cutout_mask(
# We check more than one triangle in case te global point
# is near the edge of triangle, (the lam point and global points are colinear)

t1 = Triangle3D(lam_points[index[0]], lam_points[index[1]], lam_points[index[2]])
t2 = Triangle3D(lam_points[index[1]], lam_points[index[2]], lam_points[index[3]])
t3 = Triangle3D(lam_points[index[2]], lam_points[index[3]], lam_points[index[0]])
t4 = Triangle3D(lam_points[index[3]], lam_points[index[0]], lam_points[index[1]])

# The point is inside the triangle if the intersection with the ray
# from the point to the centre of the Earth is not None
# (the direction of the ray is not important)

intersect = (
t1.intersect(zero, global_point)
or t2.intersect(zero, global_point)
or t3.intersect(zero, global_point)
or t4.intersect(zero, global_point)
)
inside = False
for j in range(neighbours):
t = Triangle3D(
lam_points[index[j]], lam_points[index[(j + 1) % neighbours]], lam_points[index[(j + 2) % neighbours]]
)
inside = t.intersect(zero, global_point)
if inside:
break

close = np.min(distance) <= min_distance

inside_lam.append(intersect or close)
inside_lam.append(inside or close)

j = 0
inside_lam = np.array(inside_lam)
Expand Down

0 comments on commit a795999

Please sign in to comment.