Skip to content

Commit

Permalink
add upsample and add to windnet in production
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Mar 4, 2024
1 parent 9e7b229 commit 77116af
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ocf_datapipes/training/windnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,11 @@ def construct_sliced_data_pipeline(
std=NWP_STDS[conf_nwp[nwp_key].nwp_provider],
)

if production:
nwp_datapipes_dict[nwp_key] = nwp_datapipes_dict[nwp_key].upsample(

Check warning on line 243 in ocf_datapipes/training/windnet.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/windnet.py#L243

Added line #L243 was not covered by tests
y_upsample=2, x_upsample=2, keep_same_shape=True
)

if "sat" in datapipes_dict:
sat_datapipe = datapipes_dict["sat"]

Expand Down
1 change: 1 addition & 0 deletions ocf_datapipes/transform/xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .create_sun_image import CreateSunImageIterDataPipe as CreateSunImage
from .create_time_image import CreateTimeImageIterDataPipe as CreateTimeImage
from .downsample import DownsampleIterDataPipe as Downsample
from .upsample import UpSampleIterDataPipe as Upsample
from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage
from .gsp.ensure_n_gsp_per_example import (
EnsureNGSPSPerExampleIterDataPipe as EnsureNGSPSPerExampleIter,
Expand Down
112 changes: 112 additions & 0 deletions ocf_datapipes/transform/xarray/upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Up Sample Xarray datasets Datapipe"""
from torch.utils.data import IterDataPipe, functional_datapipe

import logging
import numpy as np

log = logging.getLogger(__name__)


@functional_datapipe("upsample")
class UpSampleIterDataPipe(IterDataPipe):
"""Up Sample Xarray dataset with Interpolate"""

def __init__(
self,
source_datapipe: IterDataPipe,
y_upsample: int,
x_upsample: int,
x_dim_name: str = "longitude",
y_dim_name: str = "latitude",
keep_same_shape: bool = False,
):
"""
Up Sample xarray dataset/dataarrays with interpolate
Args:
source_datapipe: Datapipe emitting Xarray dataset
y_upsample: up sample value in the y direction
x_upsample: Up sample value in the x direction
x_dim_name: X dimension name
y_dim_name: Y dimension name
keep_same_shape: Optional to keep the same shape. Defaults to zero.
If True, shape is trimmed around the edges.
"""
self.source_datapipe = source_datapipe
self.y_upsample = y_upsample
self.x_upsample = x_upsample
self.x_dim_name = x_dim_name
self.y_dim_name = y_dim_name
self.keep_same_shape = keep_same_shape

def __iter__(self):
"""Coarsen the data on the specified dimensions"""
for xr_data in self.source_datapipe:

log.info("Up Sampling Data")
print(xr_data)

# get current x and y values
current_x_dim_values = getattr(xr_data, self.x_dim_name).values
current_y_dim_values = getattr(xr_data, self.y_dim_name).values

# get current interval values
current_x_interval = np.abs(current_x_dim_values[1] - current_x_dim_values[0])
current_y_interval = np.abs(current_y_dim_values[1] - current_y_dim_values[0])

# new intervals
new_x_interval = current_x_interval / self.x_upsample
new_y_interval = current_y_interval / self.y_upsample

if self.keep_same_shape:
# up sample the center of the image and keep the same shape as original image

center_x = current_x_dim_values[int(len(current_x_dim_values) / 2)]
center_y = current_y_dim_values[int(len(current_y_dim_values) / 2)]

new_x_min = center_x - new_x_interval * int(len(current_x_dim_values) / 2)
new_x_max = new_x_min + new_x_interval * (len(current_x_dim_values) - 1)

new_y_min = center_y - new_y_interval * int(len(current_y_dim_values) / 2)
new_y_max = new_y_min + new_y_interval * (len(current_y_dim_values) - 1)

else:

new_x_min = min(current_x_dim_values)
new_x_max = max(current_x_dim_values)

new_y_min = min(current_y_dim_values)
new_y_max = max(current_y_dim_values)

# get new x values
new_x_dim_values = list(
np.arange(
new_x_min,
new_x_max + new_x_interval,
new_x_interval,
)
)

# get new y values
new_y_dim_values = list(
np.arange(
new_y_min,
new_y_max + new_y_interval,
new_y_interval,
)
)

log.info(
f"Up Sampling X from ({min(current_x_dim_values)}, {current_x_interval}, "
f"{max(current_x_dim_values)}) to ({new_x_min}, {new_x_interval}, {new_x_max})"
)
log.info(
f"Up Sampling Y from ({min(current_y_dim_values)}, {current_y_interval}, "
f"{max(current_y_dim_values)}) to ({new_y_min}, {new_y_interval}, {new_y_max})"
)

# resample
xr_data = xr_data.interp(**{self.x_dim_name: new_x_dim_values})
xr_data = xr_data.interp(**{self.y_dim_name: new_y_dim_values})

yield xr_data
51 changes: 51 additions & 0 deletions tests/transform/xarray/test_upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from ocf_datapipes.transform.xarray import Upsample


def test_nwp_upsample(nwp_datapipe):
nwp_datapipe = Upsample(
nwp_datapipe, y_upsample=2, x_upsample=2, x_dim_name="x_osgb", y_dim_name="y_osgb"
)
data = next(iter(nwp_datapipe))

# Upsample by 2 from 704x548
assert data.shape[-1] == 548 * 2 - 1
assert data.shape[-2] == 704 * 2 - 1


def test_sat_downsample(sat_datapipe):
sat_datapipe = Upsample(
sat_datapipe,
y_upsample=2,
x_upsample=2,
y_dim_name="y_geostationary",
x_dim_name="x_geostationary",
)
data = next(iter(sat_datapipe))
assert data.shape[-1] == 615 * 2
assert data.shape[-2] == 298 * 2 - 1


def test_nwp_upsample_keep_same_shape(nwp_datapipe):
nwp_datapipe_new, nwp_datapipe_old = nwp_datapipe.fork(2)

nwp_datapipe_new = Upsample(
nwp_datapipe_new,
y_upsample=2,
x_upsample=2,
x_dim_name="x_osgb",
y_dim_name="y_osgb",
keep_same_shape=True,
)
data_new = next(iter(nwp_datapipe_new))
data_old = next(iter(nwp_datapipe_old))

assert data_new.shape[-1] == 548
assert data_new.shape[-2] == 704

# check first values are different
assert data_new.x_osgb.values[0] != data_old.x_osgb.values[0]
assert data_new.y_osgb.values[0] != data_old.y_osgb.values[0]

# check middle valyes are the same
assert data_new.x_osgb.values[274] == data_old.x_osgb.values[274]
assert data_new.y_osgb.values[352] == data_old.y_osgb.values[352]

0 comments on commit 77116af

Please sign in to comment.