Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 77116af

Browse files
committed
add upsample and add to windnet in production
1 parent 9e7b229 commit 77116af

File tree

4 files changed

+169
-0
lines changed

4 files changed

+169
-0
lines changed

ocf_datapipes/training/windnet.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ def construct_sliced_data_pipeline(
239239
std=NWP_STDS[conf_nwp[nwp_key].nwp_provider],
240240
)
241241

242+
if production:
243+
nwp_datapipes_dict[nwp_key] = nwp_datapipes_dict[nwp_key].upsample(
244+
y_upsample=2, x_upsample=2, keep_same_shape=True
245+
)
246+
242247
if "sat" in datapipes_dict:
243248
sat_datapipe = datapipes_dict["sat"]
244249

ocf_datapipes/transform/xarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .create_sun_image import CreateSunImageIterDataPipe as CreateSunImage
99
from .create_time_image import CreateTimeImageIterDataPipe as CreateTimeImage
1010
from .downsample import DownsampleIterDataPipe as Downsample
11+
from .upsample import UpSampleIterDataPipe as Upsample
1112
from .gsp.create_gsp_image import CreateGSPImageIterDataPipe as CreateGSPImage
1213
from .gsp.ensure_n_gsp_per_example import (
1314
EnsureNGSPSPerExampleIterDataPipe as EnsureNGSPSPerExampleIter,
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Up Sample Xarray datasets Datapipe"""
2+
from torch.utils.data import IterDataPipe, functional_datapipe
3+
4+
import logging
5+
import numpy as np
6+
7+
log = logging.getLogger(__name__)
8+
9+
10+
@functional_datapipe("upsample")
11+
class UpSampleIterDataPipe(IterDataPipe):
12+
"""Up Sample Xarray dataset with Interpolate"""
13+
14+
def __init__(
15+
self,
16+
source_datapipe: IterDataPipe,
17+
y_upsample: int,
18+
x_upsample: int,
19+
x_dim_name: str = "longitude",
20+
y_dim_name: str = "latitude",
21+
keep_same_shape: bool = False,
22+
):
23+
"""
24+
Up Sample xarray dataset/dataarrays with interpolate
25+
26+
Args:
27+
source_datapipe: Datapipe emitting Xarray dataset
28+
y_upsample: up sample value in the y direction
29+
x_upsample: Up sample value in the x direction
30+
x_dim_name: X dimension name
31+
y_dim_name: Y dimension name
32+
keep_same_shape: Optional to keep the same shape. Defaults to zero.
33+
If True, shape is trimmed around the edges.
34+
"""
35+
self.source_datapipe = source_datapipe
36+
self.y_upsample = y_upsample
37+
self.x_upsample = x_upsample
38+
self.x_dim_name = x_dim_name
39+
self.y_dim_name = y_dim_name
40+
self.keep_same_shape = keep_same_shape
41+
42+
def __iter__(self):
43+
"""Coarsen the data on the specified dimensions"""
44+
for xr_data in self.source_datapipe:
45+
46+
log.info("Up Sampling Data")
47+
print(xr_data)
48+
49+
# get current x and y values
50+
current_x_dim_values = getattr(xr_data, self.x_dim_name).values
51+
current_y_dim_values = getattr(xr_data, self.y_dim_name).values
52+
53+
# get current interval values
54+
current_x_interval = np.abs(current_x_dim_values[1] - current_x_dim_values[0])
55+
current_y_interval = np.abs(current_y_dim_values[1] - current_y_dim_values[0])
56+
57+
# new intervals
58+
new_x_interval = current_x_interval / self.x_upsample
59+
new_y_interval = current_y_interval / self.y_upsample
60+
61+
if self.keep_same_shape:
62+
# up sample the center of the image and keep the same shape as original image
63+
64+
center_x = current_x_dim_values[int(len(current_x_dim_values) / 2)]
65+
center_y = current_y_dim_values[int(len(current_y_dim_values) / 2)]
66+
67+
new_x_min = center_x - new_x_interval * int(len(current_x_dim_values) / 2)
68+
new_x_max = new_x_min + new_x_interval * (len(current_x_dim_values) - 1)
69+
70+
new_y_min = center_y - new_y_interval * int(len(current_y_dim_values) / 2)
71+
new_y_max = new_y_min + new_y_interval * (len(current_y_dim_values) - 1)
72+
73+
else:
74+
75+
new_x_min = min(current_x_dim_values)
76+
new_x_max = max(current_x_dim_values)
77+
78+
new_y_min = min(current_y_dim_values)
79+
new_y_max = max(current_y_dim_values)
80+
81+
# get new x values
82+
new_x_dim_values = list(
83+
np.arange(
84+
new_x_min,
85+
new_x_max + new_x_interval,
86+
new_x_interval,
87+
)
88+
)
89+
90+
# get new y values
91+
new_y_dim_values = list(
92+
np.arange(
93+
new_y_min,
94+
new_y_max + new_y_interval,
95+
new_y_interval,
96+
)
97+
)
98+
99+
log.info(
100+
f"Up Sampling X from ({min(current_x_dim_values)}, {current_x_interval}, "
101+
f"{max(current_x_dim_values)}) to ({new_x_min}, {new_x_interval}, {new_x_max})"
102+
)
103+
log.info(
104+
f"Up Sampling Y from ({min(current_y_dim_values)}, {current_y_interval}, "
105+
f"{max(current_y_dim_values)}) to ({new_y_min}, {new_y_interval}, {new_y_max})"
106+
)
107+
108+
# resample
109+
xr_data = xr_data.interp(**{self.x_dim_name: new_x_dim_values})
110+
xr_data = xr_data.interp(**{self.y_dim_name: new_y_dim_values})
111+
112+
yield xr_data
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from ocf_datapipes.transform.xarray import Upsample
2+
3+
4+
def test_nwp_upsample(nwp_datapipe):
5+
nwp_datapipe = Upsample(
6+
nwp_datapipe, y_upsample=2, x_upsample=2, x_dim_name="x_osgb", y_dim_name="y_osgb"
7+
)
8+
data = next(iter(nwp_datapipe))
9+
10+
# Upsample by 2 from 704x548
11+
assert data.shape[-1] == 548 * 2 - 1
12+
assert data.shape[-2] == 704 * 2 - 1
13+
14+
15+
def test_sat_downsample(sat_datapipe):
16+
sat_datapipe = Upsample(
17+
sat_datapipe,
18+
y_upsample=2,
19+
x_upsample=2,
20+
y_dim_name="y_geostationary",
21+
x_dim_name="x_geostationary",
22+
)
23+
data = next(iter(sat_datapipe))
24+
assert data.shape[-1] == 615 * 2
25+
assert data.shape[-2] == 298 * 2 - 1
26+
27+
28+
def test_nwp_upsample_keep_same_shape(nwp_datapipe):
29+
nwp_datapipe_new, nwp_datapipe_old = nwp_datapipe.fork(2)
30+
31+
nwp_datapipe_new = Upsample(
32+
nwp_datapipe_new,
33+
y_upsample=2,
34+
x_upsample=2,
35+
x_dim_name="x_osgb",
36+
y_dim_name="y_osgb",
37+
keep_same_shape=True,
38+
)
39+
data_new = next(iter(nwp_datapipe_new))
40+
data_old = next(iter(nwp_datapipe_old))
41+
42+
assert data_new.shape[-1] == 548
43+
assert data_new.shape[-2] == 704
44+
45+
# check first values are different
46+
assert data_new.x_osgb.values[0] != data_old.x_osgb.values[0]
47+
assert data_new.y_osgb.values[0] != data_old.y_osgb.values[0]
48+
49+
# check middle valyes are the same
50+
assert data_new.x_osgb.values[274] == data_old.x_osgb.values[274]
51+
assert data_new.y_osgb.values[352] == data_old.y_osgb.values[352]

0 commit comments

Comments
 (0)