Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 1b336d4

Browse files
committed
Merge branch 'main' into issue/visulaize
2 parents 859546e + a74fbb4 commit 1b336d4

File tree

10 files changed

+140
-34
lines changed

10 files changed

+140
-34
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[bumpversion]
22
commit = True
33
tag = True
4-
current_version = 3.3.46
4+
current_version = 3.3.52
55
message = Bump version: {current_version} → {new_version} [skip ci]
66

77
[bumpversion:file:setup.py]

ocf_datapipes/config/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,11 @@ class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
502502
description="The temporal resolution (in minutes) of the data."
503503
"Note that this needs to be divisible by 5.",
504504
)
505+
satellite_scaling_methods: Optional[List[str]] = Field(
506+
["mean_std"],
507+
description="There are few ways to scale the satellite data. "
508+
"1. None, 2. mean_std, 3. min_max",
509+
)
505510

506511

507512
class HRVSatellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):

ocf_datapipes/load/nwp/nwp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
self.open_nwp = open_icon_eu
4343
elif provider.lower() == "icon-global":
4444
self.open_nwp = open_icon_global
45-
elif provider.lower() == "ecmwf":
45+
elif provider.lower() in ("ecmwf", "mo_global"): # same schema so using the same loader
4646
self.open_nwp = open_ifs
4747
elif provider.lower() == "gfs":
4848
self.open_nwp = open_gfs

ocf_datapipes/select/select_spatial_slice.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,24 +103,35 @@ def _get_idx_of_pixel_closest_to_poi(
103103

104104
def _get_idx_of_pixel_closest_to_poi_geostationary(
105105
xr_data: xr.DataArray,
106-
center_osgb: Location,
106+
center_coordinate: Location,
107107
) -> Location:
108108
"""
109109
Return x and y index location of pixel at center of region of interest.
110110
111111
Args:
112112
xr_data: Xarray dataset
113-
center_osgb: Center in OSGB coordinates
113+
center_coordinate: Central coordinate
114114
115115
Returns:
116116
Location for the center pixel in geostationary coordinates
117117
"""
118-
119118
xr_coords, xr_x_dim, xr_y_dim = spatial_coord_type(xr_data)
119+
if center_coordinate.coordinate_system == "osgb":
120+
x, y = osgb_to_geostationary_area_coords(
121+
x=center_coordinate.x, y=center_coordinate.y, xr_data=xr_data
122+
)
123+
elif center_coordinate.coordinate_system == "lon_lat":
124+
x, y = lon_lat_to_geostationary_area_coords(
125+
x=center_coordinate.x, y=center_coordinate.y, xr_data=xr_data
126+
)
127+
else:
128+
raise NotImplementedError(
129+
f"Only 'osgb' and 'lon_lat' location coordinates are \
130+
supported in conversion to geostationary \
131+
- not '{center_coordinate.coordinate_system}'"
132+
)
120133

121-
x, y = osgb_to_geostationary_area_coords(x=center_osgb.x, y=center_osgb.y, xr_data=xr_data)
122134
center_geostationary = Location(x=x, y=y, coordinate_system="geostationary")
123-
124135
# Check that the requested point lies within the data
125136
assert xr_data[xr_x_dim].min() < x < xr_data[xr_x_dim].max()
126137
assert xr_data[xr_y_dim].min() < y < xr_data[xr_y_dim].max()
@@ -390,7 +401,7 @@ def select_spatial_slice_pixels(
390401
if xr_coords == "geostationary":
391402
center_idx: Location = _get_idx_of_pixel_closest_to_poi_geostationary(
392403
xr_data=xr_data,
393-
center_osgb=location,
404+
center_coordinate=location,
394405
)
395406
else:
396407
center_idx: Location = _get_idx_of_pixel_closest_to_poi(

ocf_datapipes/training/pvnet_site.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
NWP_MEANS,
2525
NWP_STDS,
2626
RSS_MEAN,
27+
RSS_RAW_MAX,
28+
RSS_RAW_MIN,
2729
RSS_STD,
2830
)
2931
from ocf_datapipes.utils.utils import (
@@ -36,29 +38,10 @@
3638
xr.set_options(keep_attrs=True)
3739
logger = logging.getLogger("pvnet_site_datapipe")
3840

39-
normalization_values = {
40-
2019: 3185.0,
41-
2020: 2678.0,
42-
2021: 3196.0,
43-
2022: 3575.0,
44-
2023: 3773.0,
45-
2024: 3773.0,
46-
}
47-
4841

4942
def normalize_pv(x: xr.DataArray):
5043
"""Normalize PV data"""
51-
# This is after the data has been temporally sliced, so have the year
52-
return x / normalization_values[2024]
53-
54-
year = x.time_utc.dt.year
55-
56-
# Add the effective_capacity_mwp to the dataset, indexed on the time_utc
57-
return (
58-
x / normalization_values[year]
59-
if year in normalization_values
60-
else x / normalization_values[2024]
61-
)
44+
return x / x.nominal_capacity_wp
6245

6346

6447
class DictDatasetIterDataPipe(IterDataPipe):
@@ -273,7 +256,11 @@ def construct_sliced_data_pipeline(
273256
roi_height_pixels=conf_sat.satellite_image_size_pixels_height,
274257
roi_width_pixels=conf_sat.satellite_image_size_pixels_width,
275258
)
276-
sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)
259+
scaling_methods = conf_sat.satellite_scaling_methods
260+
if "min_max" in scaling_methods:
261+
sat_datapipe = sat_datapipe.normalize(min_values=RSS_RAW_MIN, max_values=RSS_RAW_MAX)
262+
if "mean_std" in scaling_methods:
263+
sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)
277264

278265
if "pv" in datapipes_dict:
279266
# Recombine Sensor arrays - see function doc for further explanation

ocf_datapipes/transform/xarray/normalize.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(
2222
max_value: Optional[Union[int, float]] = None,
2323
calculate_mean_std_from_example: bool = False,
2424
normalize_fn: Optional[Callable] = None,
25+
min_values: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]] = None,
26+
max_values: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]] = None,
2527
):
2628
"""
2729
Normalize the data with either given mean/std,
@@ -37,13 +39,17 @@ def __init__(
3739
calculate_mean_std_from_example: Whether to calculate the
3840
mean/std from the input data or not
3941
normalize_fn: Callable function to apply to the data to normalize it
42+
min_values: Min values for each channel
43+
max_values: Max values for each channel
4044
"""
4145
self.source_datapipe = source_datapipe
4246
self.mean = mean
4347
self.std = std
4448
self.max_value = max_value
4549
self.calculate_mean_std_from_example = calculate_mean_std_from_example
4650
self.normalize_fn = normalize_fn
51+
self.min_values = min_values
52+
self.max_values = max_values
4753

4854
def __iter__(self) -> Union[xr.Dataset, xr.DataArray]:
4955
"""Normalize the data depending on the init arguments"""
@@ -61,6 +67,8 @@ def __iter__(self) -> Union[xr.Dataset, xr.DataArray]:
6167
# For Topo data for example
6268
xr_data -= xr_data.mean().item()
6369
xr_data /= xr_data.std().item()
70+
elif (self.min_values is not None) and (self.max_values is not None):
71+
xr_data = (xr_data - self.min_values) / (self.max_values - self.min_values)
6472
else:
6573
try:
6674
logger.debug(f"Normalizing by {self.normalize_fn}")

ocf_datapipes/utils/consts.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __getitem__(self, key):
4646
"excarta",
4747
"merra2",
4848
"merra2_uk",
49+
"mo_global",
4950
]
5051

5152
# ------ UKV
@@ -131,6 +132,24 @@ def __getitem__(self, key):
131132
UKV_STD = _to_data_array(UKV_STD)
132133
UKV_MEAN = _to_data_array(UKV_MEAN)
133134

135+
# These were calculated from 200 random init times (step 0s) from the MO global data
136+
MO_GLOBAL_INDIA_MEAN = {
137+
"temperature_sl": 298.2,
138+
"wind_u_component_10m": 0.5732,
139+
"wind_v_component_10m": -0.2831,
140+
}
141+
142+
MO_GLOBAL_INDIA_STD = {
143+
"temperature_sl": 8.473,
144+
"wind_u_component_10m": 2.599,
145+
"wind_v_component_10m": 2.016,
146+
}
147+
148+
149+
MO_GLOBAL_VARIABLE_NAMES = tuple(MO_GLOBAL_INDIA_MEAN.keys())
150+
MO_GLOBAL_INDIA_STD = _to_data_array(MO_GLOBAL_INDIA_STD)
151+
MO_GLOBAL_INDIA_MEAN = _to_data_array(MO_GLOBAL_INDIA_MEAN)
152+
134153

135154
# ------ GFS
136155
GFS_STD = {
@@ -250,6 +269,10 @@ def __getitem__(self, key):
250269
"v10": 0.02332865633070469,
251270
"v100": -0.07577426731586456,
252271
"v200": -0.1255049854516983,
272+
"diff_dlwrf": 1340142.4,
273+
"diff_dswrf": 820569.5,
274+
"diff_duvrs": 94480.24,
275+
"diff_sr": 814910.1,
253276
}
254277

255278
INDIA_ECMWF_STD = {
@@ -270,6 +293,10 @@ def __getitem__(self, key):
270293
"v10": 2.401158571243286,
271294
"v100": 3.5278923511505127,
272295
"v200": 3.974159002304077,
296+
"diff_dlwrf": 292804.8,
297+
"diff_dswrf": 1082344.9,
298+
"diff_duvrs": 125904.18,
299+
"diff_sr": 1088536.2,
273300
}
274301

275302

@@ -347,6 +374,7 @@ def __getitem__(self, key):
347374
excarta=EXCARTA_VARIABLE_NAMES,
348375
merra2=MERRA2_VARIABLE_NAMES,
349376
merra2_uk=UK_MERRA2_VARIABLE_NAMES,
377+
mo_global=MO_GLOBAL_VARIABLE_NAMES,
350378
)
351379
NWP_STDS = NWPStatDict(
352380
ukv=UKV_STD,
@@ -356,6 +384,7 @@ def __getitem__(self, key):
356384
excarta=EXCARTA_STD,
357385
merra2=MERRA2_STD,
358386
merra2_uk=UK_MERRA2_STD,
387+
mo_global=MO_GLOBAL_INDIA_STD,
359388
)
360389
NWP_MEANS = NWPStatDict(
361390
ukv=UKV_MEAN,
@@ -365,6 +394,7 @@ def __getitem__(self, key):
365394
excarta=EXCARTA_MEAN,
366395
merra2=MERRA2_MEAN,
367396
merra2_uk=UK_MERRA2_MEAN,
397+
mo_global=MO_GLOBAL_INDIA_MEAN,
368398
)
369399

370400
# --------------------------- SATELLITE ------------------------------
@@ -405,6 +435,41 @@ def __getitem__(self, key):
405435
RSS_STD = _to_data_array(RSS_STD)
406436
RSS_MEAN = _to_data_array(RSS_MEAN)
407437

438+
# normalizing from raw values
439+
440+
RSS_RAW_MIN = {
441+
"IR_016": -2.5118103,
442+
"IR_039": -64.83977,
443+
"IR_087": 63.404694,
444+
"IR_097": 2.844452,
445+
"IR_108": 199.10002,
446+
"IR_120": -17.254883,
447+
"IR_134": -26.29155,
448+
"VIS006": -1.1009827,
449+
"VIS008": -2.4184198,
450+
"WV_062": 199.57048,
451+
"WV_073": 198.95093,
452+
"HRV": -1.2278595,
453+
}
454+
455+
RSS_RAW_MAX = {
456+
"IR_016": 69.60857,
457+
"IR_039": 339.15588,
458+
"IR_087": 340.26526,
459+
"IR_097": 317.86752,
460+
"IR_108": 313.2767,
461+
"IR_120": 315.99194,
462+
"IR_134": 274.82297,
463+
"VIS006": 93.786545,
464+
"VIS008": 101.34922,
465+
"WV_062": 249.91806,
466+
"WV_073": 286.96323,
467+
"HRV": 103.90016,
468+
}
469+
470+
RSS_RAW_MIN = _to_data_array(RSS_RAW_MIN)
471+
RSS_RAW_MAX = _to_data_array(RSS_RAW_MAX)
472+
408473

409474
# --------------------------- SENSORS --------------------------------
410475

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
torch>=2.0.0
1+
torch>=2.0.0, <2.5.0
22
Cartopy>=0.20.3
33
xarray
44
zarr

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
setup(
1313
name="ocf_datapipes",
14-
version="3.3.46",
14+
version="3.3.52",
1515
license="MIT",
1616
description="Pytorch Datapipes built for use in Open Climate Fix's forecasting work",
1717
author="Jacob Bieker, Jack Kelly, Peter Dudfield, James Fulton",

tests/select/test_select_spatial_slice.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import numpy as np
22
import xarray as xr
3-
from ocf_datapipes.utils import Location
43

54
from ocf_datapipes.select import (
65
PickLocations,
76
SelectSpatialSliceMeters,
87
SelectSpatialSlicePixels,
98
)
10-
11-
from ocf_datapipes.select.select_spatial_slice import slice_spatial_pixel_window_from_xarray
9+
from ocf_datapipes.select.select_spatial_slice import (
10+
_get_idx_of_pixel_closest_to_poi_geostationary,
11+
slice_spatial_pixel_window_from_xarray,
12+
)
13+
from ocf_datapipes.utils import Location
1214

1315

1416
def test_slice_spatial_pixel_window_from_xarray_function():
@@ -158,3 +160,31 @@ def test_select_spatial_slice_meters_icon_global(passiv_datapipe, icon_global_da
158160
# ICON global has roughly 13km spacing, so this should be around 7x7 grid
159161
assert len(data.longitude) == 49
160162
assert len(data.latitude) == 49
163+
164+
165+
def test_get_idx_of_pixel_closest_to_poi_geostationary_lon_lat_location():
166+
# Create dummy data
167+
x = np.arange(5000000, -5000000, -5000)
168+
y = np.arange(5000000, -5000000, -5000)[::-1]
169+
170+
xr_data = xr.Dataset(
171+
data_vars=dict(
172+
data=(["x_geostationary", "y_geostationary"], np.random.normal(size=(len(x), len(y)))),
173+
),
174+
coords=dict(
175+
x_geostationary=(["x_geostationary"], x),
176+
y_geostationary=(["y_geostationary"], y),
177+
),
178+
)
179+
xr_data.attrs["area"] = (
180+
"msg_seviri_iodc_3km:\n description: MSG SEVIRI Indian Ocean Data Coverage service area definition with\n 3 km resolution\n projection:\n proj: geos\n lon_0: 41.5\n h: 35785831\n x_0: 0\n y_0: 0\n a: 6378169\n rf: 295.488065897014\n no_defs: null\n type: crs\n shape:\n height: 3712\n width: 3712\n area_extent:\n lower_left_xy: [5000000, 5000000]\n upper_right_xy: [-5000000, -5000000]\n units: m\n"
181+
)
182+
183+
center = Location(x=77.1, y=28.6, coordinate_system="lon_lat")
184+
185+
location_center_idx = _get_idx_of_pixel_closest_to_poi_geostationary(
186+
xr_data=xr_data, center_coordinate=center
187+
)
188+
189+
assert location_center_idx.coordinate_system == "idx"
190+
assert location_center_idx.x == 2000

0 commit comments

Comments
 (0)