Skip to content

Bugfix/overlay #22

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Keep it human-readable, your future self will thank you!

### Changed
- Change negative variance detection to make it less restrictive
- Fix cutout bug that left some global grid points in the lam part

### Removed

Expand Down
13 changes: 13 additions & 0 deletions docs/using/combining.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,16 @@ cutout:
:width: 75%
:align: center
:alt: Cutout

To debug the combination, you can pass `plot=True` to the `cutout` function
(when running from a Notebook), or use `plot="prefix"` to save the plots
to a series of PNG files in the current directory.

You can also pass a `min_distance_km` parameter to the `cutout`
function. Any grid points in the global dataset that are closer than
this distance to a grid point in the LAM dataset will be removed. This
can be useful to control the behaviour of the algorithm at the edge of
the cutout area. If no value is provided, the algorithm will compute its
value as the smallest distance between two grid points in the global
dataset over the cutout area. If you do not want to use this feature,
you can set `min_distance_km=0`.
17 changes: 14 additions & 3 deletions src/anemoi/datasets/data/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def tree(self):


class Cutout(GridsBase):
def __init__(self, datasets, axis):
def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, plot=False):
from anemoi.datasets.grids import cutout_mask

super().__init__(datasets, axis)
Expand All @@ -144,7 +144,9 @@ def __init__(self, datasets, axis):
self.lam.longitudes,
self.globe.latitudes,
self.globe.longitudes,
# plot="cutout",
plot=plot,
min_distance_km=min_distance_km,
cropping_distance=cropping_distance,
)
assert len(self.mask) == self.globe.shape[3], (
len(self.mask),
Expand Down Expand Up @@ -229,11 +231,20 @@ def cutout_factory(args, kwargs):

cutout = kwargs.pop("cutout")
axis = kwargs.pop("axis", 3)
plot = kwargs.pop("plot", None)
min_distance_km = kwargs.pop("min_distance_km", None)
cropping_distance = kwargs.pop("cropping_distance", 2.0)

assert len(args) == 0
assert isinstance(cutout, (list, tuple))

datasets = [_open(e) for e in cutout]
datasets, kwargs = _auto_adjust(datasets, kwargs)

return Cutout(datasets, axis=axis)._subset(**kwargs)
return Cutout(
datasets,
axis=axis,
min_distance_km=min_distance_km,
cropping_distance=cropping_distance,
plot=plot,
)._subset(**kwargs)
105 changes: 66 additions & 39 deletions src/anemoi/datasets/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,65 @@
# nor does it submit to any jurisdiction.
#

import logging

import numpy as np

LOG = logging.getLogger(__name__)


def plot_mask(path, mask, lats, lons, global_lats, global_lons):
import matplotlib.pyplot as plt

middle = (np.amin(lons) + np.amax(lons)) / 2
print("middle", middle)
s = 1

# gmiddle = (np.amin(global_lons)+ np.amax(global_lons))/2

# print('gmiddle', gmiddle)
# global_lons = global_lons-gmiddle+middle
global_lons[global_lons >= 180] -= 360

plt.figure(figsize=(10, 5))
plt.scatter(global_lons, global_lats, s=s, marker="o", c="r")
plt.savefig(path + "-global.png")
if isinstance(path, str):
plt.savefig(path + "-global.png")

plt.figure(figsize=(10, 5))
plt.scatter(global_lons[mask], global_lats[mask], s=s, c="k")
plt.savefig(path + "-cutout.png")
if isinstance(path, str):
plt.savefig(path + "-cutout.png")

plt.figure(figsize=(10, 5))
plt.scatter(lons, lats, s=s)
plt.savefig(path + "-lam.png")
if isinstance(path, str):
plt.savefig(path + "-lam.png")
# plt.scatter(lons, lats, s=0.01)

plt.figure(figsize=(10, 5))
plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r")
plt.scatter(lons, lats, s=s)
plt.savefig(path + "-both.png")
if isinstance(path, str):
plt.savefig(path + "-both.png")
# plt.scatter(lons, lats, s=0.01)

plt.figure(figsize=(10, 5))
plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r")
plt.scatter(lons, lats, s=s)
plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1)
plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1)
if isinstance(path, str):
plt.savefig(path + "-both-zoomed.png")

plt.figure(figsize=(10, 5))
plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r")
plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1)
plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1)
if isinstance(path, str):
plt.savefig(path + "-global-zoomed.png")


def xyz_to_latlon(x, y, z):
    """Convert Cartesian coordinates to latitude and longitude in degrees.

    The latitude is the arcsine of ``z`` and the longitude is
    ``arctan2(y, x)``; both are converted from radians to degrees.
    Presumably the inputs lie on the unit sphere (the inverse of
    ``latlon_to_xyz`` with ``radius=1.0``) -- confirm against callers.

    Returns a ``(latitude, longitude)`` tuple.
    """
    # Clamp z to [-1, 1] so that values slightly outside the arcsin domain
    # (e.g. from floating-point rounding) do not yield NaN.
    clamped_z = np.minimum(1.0, np.maximum(-1.0, z))
    latitude = np.rad2deg(np.arcsin(clamped_z))
    longitude = np.rad2deg(np.arctan2(y, x))
    return (latitude, longitude)


def latlon_to_xyz(lat, lon, radius=1.0):
# https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates
Expand Down Expand Up @@ -126,6 +150,7 @@ def cutout_mask(
):
"""Return a mask for the points in [global_lats, global_lons] that are inside of [lats, lons]"""
from scipy.spatial import KDTree
from scipy.spatial import distance_matrix

# TODO: transform min_distance from lat/lon to xyz

Expand Down Expand Up @@ -166,56 +191,58 @@ def cutout_mask(

# Use a KDTree to find the nearest points
kdtree = KDTree(lam_points)
distances, indices = kdtree.query(global_points, k=3)
distances, indices = kdtree.query(global_points, k=4)

if min_distance_km is not None:
min_distance = min_distance_km / 6371.0
else:
# Estimation of the minimum distance between two grid points

glats = sorted(set(global_lats_masked))
glons = sorted(set(global_lons_masked))
min_dlats = np.min(np.diff(glats))
min_dlons = np.min(np.diff(glons))

# Use the centre of the LAM grid as the reference point
centre = np.mean(lats), np.mean(lons)
centre_xyz = np.array(latlon_to_xyz(*centre))

pt1 = np.array(latlon_to_xyz(centre[0] + min_dlats, centre[1]))
pt2 = np.array(latlon_to_xyz(centre[0], centre[1] + min_dlons))
min_distance = (
min(
np.linalg.norm(pt1 - centre_xyz),
np.linalg.norm(pt2 - centre_xyz),
)
/ 2.0
)
min_distance = 0
dm = distance_matrix(global_points, global_points)
min_distance = np.min(dm[dm > 0])

LOG.debug(f"cutout_mask using min_distance = {min_distance * 6371.0} km")

# Centre of the Earth
zero = np.array([0.0, 0.0, 0.0])
ok = []

# After the loop, 'inside_lam' will contain a list of points to EXCLUDE
inside_lam = []

for i, (global_point, distance, index) in enumerate(zip(global_points, distances, indices)):
t = Triangle3D(lam_points[index[0]], lam_points[index[1]], lam_points[index[2]])
# distance = np.min(distance)

# We check more than one triangle in case the global point
# is near the edge of a triangle (the lam point and global points are colinear)

t1 = Triangle3D(lam_points[index[0]], lam_points[index[1]], lam_points[index[2]])
t2 = Triangle3D(lam_points[index[1]], lam_points[index[2]], lam_points[index[3]])
t3 = Triangle3D(lam_points[index[2]], lam_points[index[3]], lam_points[index[0]])
t4 = Triangle3D(lam_points[index[3]], lam_points[index[0]], lam_points[index[1]])

# The point is inside the triangle if the intersection with the ray
# from the point to the centre of the Earth is not None
# (the direction of the ray is not important)

intersect = t.intersect(zero, global_point)
intersect = (
t1.intersect(zero, global_point)
or t2.intersect(zero, global_point)
or t3.intersect(zero, global_point)
or t4.intersect(zero, global_point)
)

close = np.min(distance) <= min_distance

ok.append(intersect or close)
inside_lam.append(intersect or close)

j = 0
ok = np.array(ok)
inside_lam = np.array(inside_lam)
for i, m in enumerate(mask):
if not m:
continue

mask[i] = ok[j]
mask[i] = inside_lam[j]
j += 1

assert j == len(ok)
assert j == len(inside_lam)

# Invert the mask, so we have only the points outside the cutout
mask = ~mask
Expand Down
Loading