diff --git a/ocf_datapipes/training/windnet.py b/ocf_datapipes/training/windnet.py index ea5f0a06c..8b9231405 100644 --- a/ocf_datapipes/training/windnet.py +++ b/ocf_datapipes/training/windnet.py @@ -239,6 +239,11 @@ def construct_sliced_data_pipeline( std=NWP_STDS[conf_nwp[nwp_key].nwp_provider], ) + if production: + nwp_datapipes_dict[nwp_key] = nwp_datapipes_dict[nwp_key].upsample( + y_upsample=2, x_upsample=2, keep_same_shape=True + ) + if "sat" in datapipes_dict: sat_datapipe = datapipes_dict["sat"] diff --git a/ocf_datapipes/transform/xarray/__init__.py b/ocf_datapipes/transform/xarray/__init__.py index cf00a2c72..d6b625a87 100644 --- a/ocf_datapipes/transform/xarray/__init__.py +++ b/ocf_datapipes/transform/xarray/__init__.py @@ -8,6 +8,7 @@ from .create_sun_image import CreateSunImageIterDataPipe as CreateSunImage from .create_time_image import CreateTimeImageIterDataPipe as CreateTimeImage from .downsample import DownsampleIterDataPipe as Downsample +from .upsample import UpSampleIterDataPipe as Upsample from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage from .gsp.ensure_n_gsp_per_example import ( EnsureNGSPSPerExampleIterDataPipe as EnsureNGSPSPerExampleIter, diff --git a/ocf_datapipes/transform/xarray/upsample.py b/ocf_datapipes/transform/xarray/upsample.py new file mode 100644 index 000000000..4134068dd --- /dev/null +++ b/ocf_datapipes/transform/xarray/upsample.py @@ -0,0 +1,112 @@ +"""Up Sample Xarray datasets Datapipe""" +from torch.utils.data import IterDataPipe, functional_datapipe + +import logging +import numpy as np + +log = logging.getLogger(__name__) + + +@functional_datapipe("upsample") +class UpSampleIterDataPipe(IterDataPipe): + """Up Sample Xarray dataset with Interpolate""" + + def __init__( + self, + source_datapipe: IterDataPipe, + y_upsample: int, + x_upsample: int, + x_dim_name: str = "longitude", + y_dim_name: str = "latitude", + keep_same_shape: bool = False, + ): + """ + Up Sample xarray dataset/dataarrays with interpolate + + Args: + source_datapipe: Datapipe emitting Xarray dataset + y_upsample: up sample value in the y direction + x_upsample: Up sample value in the x direction + x_dim_name: X dimension name + y_dim_name: Y dimension name + keep_same_shape: Optional to keep the same shape. Defaults to zero. + If True, shape is trimmed around the edges. + """ + self.source_datapipe = source_datapipe + self.y_upsample = y_upsample + self.x_upsample = x_upsample + self.x_dim_name = x_dim_name + self.y_dim_name = y_dim_name + self.keep_same_shape = keep_same_shape + + def __iter__(self): + """Coarsen the data on the specified dimensions""" + for xr_data in self.source_datapipe: + + log.info("Up Sampling Data") + print(xr_data) + + # get current x and y values + current_x_dim_values = getattr(xr_data, self.x_dim_name).values + current_y_dim_values = getattr(xr_data, self.y_dim_name).values + + # get current interval values + current_x_interval = np.abs(current_x_dim_values[1] - current_x_dim_values[0]) + current_y_interval = np.abs(current_y_dim_values[1] - current_y_dim_values[0]) + + # new intervals + new_x_interval = current_x_interval / self.x_upsample + new_y_interval = current_y_interval / self.y_upsample + + if self.keep_same_shape: + # up sample the center of the image and keep the same shape as original image + + center_x = current_x_dim_values[int(len(current_x_dim_values) / 2)] + center_y = current_y_dim_values[int(len(current_y_dim_values) / 2)] + + new_x_min = center_x - new_x_interval * int(len(current_x_dim_values) / 2) + new_x_max = new_x_min + new_x_interval * (len(current_x_dim_values) - 1) + + new_y_min = center_y - new_y_interval * int(len(current_y_dim_values) / 2) + new_y_max = new_y_min + new_y_interval * (len(current_y_dim_values) - 1) + + else: + + new_x_min = min(current_x_dim_values) + new_x_max = max(current_x_dim_values) + + new_y_min = min(current_y_dim_values) + new_y_max = max(current_y_dim_values) + + # get new x values + new_x_dim_values = list( + np.arange( + new_x_min, + new_x_max + new_x_interval, + new_x_interval, + ) + ) + + # get new y values + new_y_dim_values = list( + np.arange( + new_y_min, + new_y_max + new_y_interval, + new_y_interval, + ) + ) + + log.info( + f"Up Sampling X from ({min(current_x_dim_values)}, {current_x_interval}, " + f"{max(current_x_dim_values)}) to ({new_x_min}, {new_x_interval}, {new_x_max})" + ) + log.info( + f"Up Sampling Y from ({min(current_y_dim_values)}, {current_y_interval}, " + f"{max(current_y_dim_values)}) to ({new_y_min}, {new_y_interval}, {new_y_max})" + ) + + # resample + xr_data = xr_data.interp(**{self.x_dim_name: new_x_dim_values}) + xr_data = xr_data.interp(**{self.y_dim_name: new_y_dim_values}) + + yield xr_data diff --git a/tests/transform/xarray/test_upsample.py b/tests/transform/xarray/test_upsample.py new file mode 100644 index 000000000..74faf5d68 --- /dev/null +++ b/tests/transform/xarray/test_upsample.py @@ -0,0 +1,51 @@ +from ocf_datapipes.transform.xarray import Upsample + + +def test_nwp_upsample(nwp_datapipe): + nwp_datapipe = Upsample( + nwp_datapipe, y_upsample=2, x_upsample=2, x_dim_name="x_osgb", y_dim_name="y_osgb" + ) + data = next(iter(nwp_datapipe)) + + # Upsample by 2 from 704x548 + assert data.shape[-1] == 548 * 2 - 1 + assert data.shape[-2] == 704 * 2 - 1 + + +def test_sat_downsample(sat_datapipe): + sat_datapipe = Upsample( + sat_datapipe, + y_upsample=2, + x_upsample=2, + y_dim_name="y_geostationary", + x_dim_name="x_geostationary", + ) + data = next(iter(sat_datapipe)) + assert data.shape[-1] == 615 * 2 + assert data.shape[-2] == 298 * 2 - 1 + + +def test_nwp_upsample_keep_same_shape(nwp_datapipe): + nwp_datapipe_new, nwp_datapipe_old = nwp_datapipe.fork(2) + + nwp_datapipe_new = Upsample( + nwp_datapipe_new, + y_upsample=2, + x_upsample=2, + x_dim_name="x_osgb", + y_dim_name="y_osgb", + keep_same_shape=True, + ) + data_new = next(iter(nwp_datapipe_new)) + data_old = next(iter(nwp_datapipe_old)) + + assert data_new.shape[-1] == 548 + assert data_new.shape[-2] == 704 + + # check first values are different + assert data_new.x_osgb.values[0] != data_old.x_osgb.values[0] + assert data_new.y_osgb.values[0] != data_old.y_osgb.values[0] + + # check middle valyes are the same + assert data_new.x_osgb.values[274] == data_old.x_osgb.values[274] + assert data_new.y_osgb.values[352] == data_old.y_osgb.values[352]