-
-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add upsample and add to windnet in production
- Loading branch information
1 parent
9e7b229
commit 77116af
Showing
4 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
"""Up Sample Xarray datasets Datapipe""" | ||
from torch.utils.data import IterDataPipe, functional_datapipe | ||
|
||
import logging | ||
import numpy as np | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
@functional_datapipe("upsample") | ||
class UpSampleIterDataPipe(IterDataPipe): | ||
"""Up Sample Xarray dataset with Interpolate""" | ||
|
||
def __init__( | ||
self, | ||
source_datapipe: IterDataPipe, | ||
y_upsample: int, | ||
x_upsample: int, | ||
x_dim_name: str = "longitude", | ||
y_dim_name: str = "latitude", | ||
keep_same_shape: bool = False, | ||
): | ||
""" | ||
Up Sample xarray dataset/dataarrays with interpolate | ||
Args: | ||
source_datapipe: Datapipe emitting Xarray dataset | ||
y_upsample: up sample value in the y direction | ||
x_upsample: Up sample value in the x direction | ||
x_dim_name: X dimension name | ||
y_dim_name: Y dimension name | ||
keep_same_shape: Optional to keep the same shape. Defaults to zero. | ||
If True, shape is trimmed around the edges. | ||
""" | ||
self.source_datapipe = source_datapipe | ||
self.y_upsample = y_upsample | ||
self.x_upsample = x_upsample | ||
self.x_dim_name = x_dim_name | ||
self.y_dim_name = y_dim_name | ||
self.keep_same_shape = keep_same_shape | ||
|
||
def __iter__(self): | ||
"""Coarsen the data on the specified dimensions""" | ||
for xr_data in self.source_datapipe: | ||
|
||
log.info("Up Sampling Data") | ||
print(xr_data) | ||
|
||
# get current x and y values | ||
current_x_dim_values = getattr(xr_data, self.x_dim_name).values | ||
current_y_dim_values = getattr(xr_data, self.y_dim_name).values | ||
|
||
# get current interval values | ||
current_x_interval = np.abs(current_x_dim_values[1] - current_x_dim_values[0]) | ||
current_y_interval = np.abs(current_y_dim_values[1] - current_y_dim_values[0]) | ||
|
||
# new intervals | ||
new_x_interval = current_x_interval / self.x_upsample | ||
new_y_interval = current_y_interval / self.y_upsample | ||
|
||
if self.keep_same_shape: | ||
# up sample the center of the image and keep the same shape as original image | ||
|
||
center_x = current_x_dim_values[int(len(current_x_dim_values) / 2)] | ||
center_y = current_y_dim_values[int(len(current_y_dim_values) / 2)] | ||
|
||
new_x_min = center_x - new_x_interval * int(len(current_x_dim_values) / 2) | ||
new_x_max = new_x_min + new_x_interval * (len(current_x_dim_values) - 1) | ||
|
||
new_y_min = center_y - new_y_interval * int(len(current_y_dim_values) / 2) | ||
new_y_max = new_y_min + new_y_interval * (len(current_y_dim_values) - 1) | ||
|
||
else: | ||
|
||
new_x_min = min(current_x_dim_values) | ||
new_x_max = max(current_x_dim_values) | ||
|
||
new_y_min = min(current_y_dim_values) | ||
new_y_max = max(current_y_dim_values) | ||
|
||
# get new x values | ||
new_x_dim_values = list( | ||
np.arange( | ||
new_x_min, | ||
new_x_max + new_x_interval, | ||
new_x_interval, | ||
) | ||
) | ||
|
||
# get new y values | ||
new_y_dim_values = list( | ||
np.arange( | ||
new_y_min, | ||
new_y_max + new_y_interval, | ||
new_y_interval, | ||
) | ||
) | ||
|
||
log.info( | ||
f"Up Sampling X from ({min(current_x_dim_values)}, {current_x_interval}, " | ||
f"{max(current_x_dim_values)}) to ({new_x_min}, {new_x_interval}, {new_x_max})" | ||
) | ||
log.info( | ||
f"Up Sampling Y from ({min(current_y_dim_values)}, {current_y_interval}, " | ||
f"{max(current_y_dim_values)}) to ({new_y_min}, {new_y_interval}, {new_y_max})" | ||
) | ||
|
||
# resample | ||
xr_data = xr_data.interp(**{self.x_dim_name: new_x_dim_values}) | ||
xr_data = xr_data.interp(**{self.y_dim_name: new_y_dim_values}) | ||
|
||
yield xr_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from ocf_datapipes.transform.xarray import Upsample | ||
|
||
|
||
def test_nwp_upsample(nwp_datapipe): | ||
nwp_datapipe = Upsample( | ||
nwp_datapipe, y_upsample=2, x_upsample=2, x_dim_name="x_osgb", y_dim_name="y_osgb" | ||
) | ||
data = next(iter(nwp_datapipe)) | ||
|
||
# Upsample by 2 from 704x548 | ||
assert data.shape[-1] == 548 * 2 - 1 | ||
assert data.shape[-2] == 704 * 2 - 1 | ||
|
||
|
||
def test_sat_downsample(sat_datapipe): | ||
sat_datapipe = Upsample( | ||
sat_datapipe, | ||
y_upsample=2, | ||
x_upsample=2, | ||
y_dim_name="y_geostationary", | ||
x_dim_name="x_geostationary", | ||
) | ||
data = next(iter(sat_datapipe)) | ||
assert data.shape[-1] == 615 * 2 | ||
assert data.shape[-2] == 298 * 2 - 1 | ||
|
||
|
||
def test_nwp_upsample_keep_same_shape(nwp_datapipe): | ||
nwp_datapipe_new, nwp_datapipe_old = nwp_datapipe.fork(2) | ||
|
||
nwp_datapipe_new = Upsample( | ||
nwp_datapipe_new, | ||
y_upsample=2, | ||
x_upsample=2, | ||
x_dim_name="x_osgb", | ||
y_dim_name="y_osgb", | ||
keep_same_shape=True, | ||
) | ||
data_new = next(iter(nwp_datapipe_new)) | ||
data_old = next(iter(nwp_datapipe_old)) | ||
|
||
assert data_new.shape[-1] == 548 | ||
assert data_new.shape[-2] == 704 | ||
|
||
# check first values are different | ||
assert data_new.x_osgb.values[0] != data_old.x_osgb.values[0] | ||
assert data_new.y_osgb.values[0] != data_old.y_osgb.values[0] | ||
|
||
# check middle valyes are the same | ||
assert data_new.x_osgb.values[274] == data_old.x_osgb.values[274] | ||
assert data_new.y_osgb.values[352] == data_old.y_osgb.values[352] |