Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PVNet Site Datapipe #267

Merged
merged 3 commits into from
Jan 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
399 changes: 399 additions & 0 deletions ocf_datapipes/training/pvnet_site.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,399 @@
"""Create the training/validation datapipe for training the PVNet Model"""
import logging
from datetime import datetime, timedelta
from typing import List, Optional

Check warning on line 4 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L2-L4

Added lines #L2 - L4 were not covered by tests

import xarray as xr
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper

Check warning on line 8 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L6-L8

Added lines #L6 - L8 were not covered by tests

from ocf_datapipes.batch import MergeNumpyModalities, MergeNWPNumpyModalities
from ocf_datapipes.config.model import Configuration
from ocf_datapipes.load import (

Check warning on line 12 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L10-L12

Added lines #L10 - L12 were not covered by tests
OpenConfiguration,
)
from ocf_datapipes.training.common import (

Check warning on line 15 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L15

Added line #L15 was not covered by tests
DatapipeKeyForker,
_get_datapipes_dict,
concat_xr_time_utc,
construct_loctime_pipelines,
fill_nans_in_arrays,
fill_nans_in_pv,
normalize_gsp,
slice_datapipes_by_time,
)
from ocf_datapipes.utils.consts import (

Check warning on line 25 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L25

Added line #L25 was not covered by tests
NWP_MEANS,
NWP_STDS,
RSS_MEAN,
RSS_STD,
)
from ocf_datapipes.utils.utils import (

Check warning on line 31 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L31

Added line #L31 was not covered by tests
combine_to_single_dataset,
flatten_nwp_source_dict,
nest_nwp_source_dict,
uncombine_from_single_dataset,
)

xr.set_options(keep_attrs=True)
logger = logging.getLogger("pvnet_site_datapipe")

Check warning on line 39 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L38-L39

Added lines #L38 - L39 were not covered by tests


def normalize_pv(x: xr.DataArray):

Check warning on line 42 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L42

Added line #L42 was not covered by tests
"""Normalize PV data"""
return x / 3500.0 # TODO Check the actual max value

Check warning on line 44 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L44

Added line #L44 was not covered by tests


class DictDatasetIterDataPipe(IterDataPipe):

Check warning on line 47 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L47

Added line #L47 was not covered by tests
"""Create a dictionary of xr.Datasets from a dict of datapipes"""

datapipes_dict: dict[IterDataPipe]
length: Optional[int]

Check warning on line 51 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L50-L51

Added lines #L50 - L51 were not covered by tests

def __init__(self, datapipes_dict: dict[IterDataPipe]):

Check warning on line 53 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L53

Added line #L53 was not covered by tests
"""Init"""
# Flatten the dict of input datapipes (NWP is nested)
self.datapipes_dict = flatten_nwp_source_dict(datapipes_dict)
self.length = None

Check warning on line 57 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L56-L57

Added lines #L56 - L57 were not covered by tests

# Run checks
is_okay = all([isinstance(dp, IterDataPipe) for k, dp in self.datapipes_dict.items()])

Check warning on line 60 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L60

Added line #L60 was not covered by tests

if not is_okay:
raise TypeError(

Check warning on line 63 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L62-L63

Added lines #L62 - L63 were not covered by tests
"All inputs are required to be `IterDataPipe` " "for `ZipIterDataPipe`."
)

super().__init__()

Check warning on line 67 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L67

Added line #L67 was not covered by tests

def __iter__(self):

Check warning on line 69 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L69

Added line #L69 was not covered by tests
"""Iter"""
all_keys = []
all_datapipes = []
for k, dp in self.datapipes_dict.items():
all_keys += [k]
all_datapipes += [dp]

Check warning on line 75 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L71-L75

Added lines #L71 - L75 were not covered by tests

zipped_datapipes = all_datapipes[0].zip_ocf(*all_datapipes[1:])

Check warning on line 77 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L77

Added line #L77 was not covered by tests

for values in zipped_datapipes:
output_dict = {key: x for key, x in zip(all_keys, values)}

Check warning on line 80 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L79-L80

Added lines #L79 - L80 were not covered by tests

# re-nest the nwp keys
output_dict = nest_nwp_source_dict(output_dict)

Check warning on line 83 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L83

Added line #L83 was not covered by tests

yield output_dict

Check warning on line 85 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L85

Added line #L85 was not covered by tests


class LoadDictDatasetIterDataPipe(IterDataPipe):

Check warning on line 88 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L88

Added line #L88 was not covered by tests
"""Load NetCDF files and split them back into individual xr.Datasets"""

filenames: List[str]
keys: List[str]

Check warning on line 92 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L91-L92

Added lines #L91 - L92 were not covered by tests

def __init__(self, filenames: List[str], keys: List[str]):

Check warning on line 94 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L94

Added line #L94 was not covered by tests
"""
Load NetCDF files and split them back into individual xr.Datasets

Args:
filenames: List of filesnames to load
keys: List of keys from each file to use, each key should be a
dataarray in the xr.Dataset
"""
super().__init__()
self.keys = keys
self.filenames = filenames

Check warning on line 105 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L103-L105

Added lines #L103 - L105 were not covered by tests

def __iter__(self):

Check warning on line 107 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L107

Added line #L107 was not covered by tests
"""Iterate through each filename, loading it, uncombining it, and then yielding it"""
while True:
for filename in self.filenames:
dataset = xr.open_dataset(filename)
datasets = uncombine_from_single_dataset(dataset)

Check warning on line 112 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L109-L112

Added lines #L109 - L112 were not covered by tests
# Yield a dictionary of the data, using the keys in self.keys
dataset_dict = {}
for k in self.keys:
dataset_dict[k] = datasets[k]
yield dataset_dict

Check warning on line 117 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L114-L117

Added lines #L114 - L117 were not covered by tests


@functional_datapipe("pvnet_site_convert_to_numpy_batch")
class ConvertToNumpyBatchIterDataPipe(IterDataPipe):

Check warning on line 121 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L120-L121

Added lines #L120 - L121 were not covered by tests
"""Converts Xarray Dataset to Numpy Batch"""

def __init__(

Check warning on line 124 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L124

Added line #L124 was not covered by tests
self,
dataset_dict_dp: IterDataPipe,
configuration: Configuration,
check_satellite_no_zeros: bool = False,
):
"""Init"""
super().__init__()
self.dataset_dict_dp = dataset_dict_dp
self.configuration = configuration
self.check_satellite_no_zeros = check_satellite_no_zeros

Check warning on line 134 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L131-L134

Added lines #L131 - L134 were not covered by tests

def __iter__(self):

Check warning on line 136 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L136

Added line #L136 was not covered by tests
"""Iter"""
for datapipes_dict in self.dataset_dict_dp:

Check warning on line 138 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L138

Added line #L138 was not covered by tests
# Spatially slice, normalize, and convert data to numpy arrays
numpy_modalities = []

Check warning on line 140 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L140

Added line #L140 was not covered by tests

if "nwp" in datapipes_dict:

Check warning on line 142 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L142

Added line #L142 was not covered by tests
# Combine the NWPs into NumpyBatch
nwp_numpy_modalities = dict()
for nwp_key, nwp_datapipe in datapipes_dict["nwp"].items():
nwp_numpy_modalities[nwp_key] = nwp_datapipe.convert_nwp_to_numpy_batch()

Check warning on line 146 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L144-L146

Added lines #L144 - L146 were not covered by tests

nwp_numpy_modalities = MergeNWPNumpyModalities(nwp_numpy_modalities)
numpy_modalities.append(nwp_numpy_modalities)

Check warning on line 149 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L148-L149

Added lines #L148 - L149 were not covered by tests

if "sat" in datapipes_dict:
numpy_modalities.append(datapipes_dict["sat"].convert_satellite_to_numpy_batch())
if "pv" in datapipes_dict:
numpy_modalities.append(datapipes_dict["pv"].convert_pv_to_numpy_batch())
if "gsp" in datapipes_dict:
numpy_modalities.append(datapipes_dict["gsp"].convert_gsp_to_numpy_batch())
if "sensor" in datapipes_dict:
numpy_modalities.append(datapipes_dict["sensor"].convert_sensor_to_numpy_batch())
if "wind" in datapipes_dict:
numpy_modalities.append(datapipes_dict["wind"].convert_wind_to_numpy_batch())

Check warning on line 160 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L151-L160

Added lines #L151 - L160 were not covered by tests

logger.debug("Combine all the data sources")
combined_datapipe = MergeNumpyModalities(numpy_modalities)

Check warning on line 163 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L162-L163

Added lines #L162 - L163 were not covered by tests

logger.info("Filtering out samples with no data")

Check warning on line 165 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L165

Added line #L165 was not covered by tests
# if self.check_satellite_no_zeros:
# in production we don't want any nans in the satellite data
# combined_datapipe = combined_datapipe.map(check_nans_in_satellite_data)

combined_datapipe = combined_datapipe.map(fill_nans_in_arrays)

Check warning on line 170 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L170

Added line #L170 was not covered by tests

yield next(iter(combined_datapipe))

Check warning on line 172 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L172

Added line #L172 was not covered by tests


def minutes(num_mins: int):

Check warning on line 175 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L175

Added line #L175 was not covered by tests
"""Timedelta of a number of minutes.

Args:
num_mins: Minutes timedelta.
"""
return timedelta(minutes=num_mins)

Check warning on line 181 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L181

Added line #L181 was not covered by tests


def construct_sliced_data_pipeline(

Check warning on line 184 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L184

Added line #L184 was not covered by tests
config_filename: str,
location_pipe: IterDataPipe,
t0_datapipe: IterDataPipe,
production: bool = False,
) -> dict:
"""Constructs data pipeline for the input data config file.

This yields samples from the location and time datapipes.

Args:
config_filename: Path to config file.
location_pipe: Datapipe yielding locations.
t0_datapipe: Datapipe yielding times.
production: Whether constucting pipeline for production inference.
"""

datapipes_dict = _get_datapipes_dict(

Check warning on line 201 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L201

Added line #L201 was not covered by tests
config_filename,
production=production,
)

configuration = datapipes_dict.pop("config")

Check warning on line 206 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L206

Added line #L206 was not covered by tests

# Unpack for convenience
conf_sat = configuration.input_data.satellite
conf_nwp = configuration.input_data.nwp

Check warning on line 210 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L209-L210

Added lines #L209 - L210 were not covered by tests

# Slice all of the datasets by time - this is an in-place operation
slice_datapipes_by_time(datapipes_dict, t0_datapipe, configuration, production)

Check warning on line 213 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L213

Added line #L213 was not covered by tests

# We need a copy of the location datapipe for all keys in fork_keys
fork_keys = set(k for k in datapipes_dict.keys())
if "nwp" in datapipes_dict: # NWP is nested
fork_keys.update(set(f"nwp/{k}" for k in datapipes_dict["nwp"].keys()))

Check warning on line 218 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L216-L218

Added lines #L216 - L218 were not covered by tests

# We don't need somes keys even if they are in the data dictionary
fork_keys = fork_keys - set(

Check warning on line 221 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L221

Added line #L221 was not covered by tests
["topo", "nwp", "wind", "wind_future", "sensor", "hrv", "pv_future", "pv"]
)

# Set up a key-forker for all the data sources we need it for
get_loc_datapipe = DatapipeKeyForker(fork_keys, location_pipe)

Check warning on line 226 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L226

Added line #L226 was not covered by tests

if "nwp" in datapipes_dict:
nwp_datapipes_dict = dict()

Check warning on line 229 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L228-L229

Added lines #L228 - L229 were not covered by tests

for nwp_key, nwp_datapipe in datapipes_dict["nwp"].items():
location_pipe, location_pipe_copy = location_pipe.fork(2, buffer_size=5)
nwp_datapipe = nwp_datapipe.select_spatial_slice_pixels(

Check warning on line 233 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L231-L233

Added lines #L231 - L233 were not covered by tests
get_loc_datapipe(f"nwp/{nwp_key}"),
roi_height_pixels=conf_nwp[nwp_key].nwp_image_size_pixels_height,
roi_width_pixels=conf_nwp[nwp_key].nwp_image_size_pixels_width,
)
nwp_datapipes_dict[nwp_key] = nwp_datapipe.normalize(

Check warning on line 238 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L238

Added line #L238 was not covered by tests
mean=NWP_MEANS[conf_nwp[nwp_key].nwp_provider],
std=NWP_STDS[conf_nwp[nwp_key].nwp_provider],
)

if "sat" in datapipes_dict:
sat_datapipe = datapipes_dict["sat"]

Check warning on line 244 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L243-L244

Added lines #L243 - L244 were not covered by tests

sat_datapipe = sat_datapipe.select_spatial_slice_pixels(

Check warning on line 246 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L246

Added line #L246 was not covered by tests
get_loc_datapipe("sat"),
roi_height_pixels=conf_sat.satellite_image_size_pixels_height,
roi_width_pixels=conf_sat.satellite_image_size_pixels_width,
)
sat_datapipe = sat_datapipe.normalize(mean=RSS_MEAN, std=RSS_STD)

Check warning on line 251 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L251

Added line #L251 was not covered by tests

if "pv" in datapipes_dict:

Check warning on line 253 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L253

Added line #L253 was not covered by tests
# Recombine Sensor arrays - see function doc for further explanation
pv_datapipe = (

Check warning on line 255 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L255

Added line #L255 was not covered by tests
datapipes_dict["pv"].zip_ocf(datapipes_dict["pv_future"]).map(concat_xr_time_utc)
)
pv_datapipe = pv_datapipe.normalize(normalize_fn=normalize_pv)
pv_datapipe = pv_datapipe.map(fill_nans_in_pv)

Check warning on line 259 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L258-L259

Added lines #L258 - L259 were not covered by tests

finished_dataset_dict = {"config": configuration}

Check warning on line 261 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L261

Added line #L261 was not covered by tests

if "gsp" in datapipes_dict:
gsp_future_datapipe = datapipes_dict["gsp_future"]
gsp_future_datapipe = gsp_future_datapipe.select_spatial_slice_meters(

Check warning on line 265 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L263-L265

Added lines #L263 - L265 were not covered by tests
location_datapipe=get_loc_datapipe("gsp_future"),
roi_height_meters=1,
roi_width_meters=1,
dim_name="gsp_id",
)

gsp_datapipe = datapipes_dict["gsp"]
gsp_datapipe = gsp_datapipe.select_spatial_slice_meters(

Check warning on line 273 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L272-L273

Added lines #L272 - L273 were not covered by tests
location_datapipe=get_loc_datapipe("gsp"),
roi_height_meters=1,
roi_width_meters=1,
dim_name="gsp_id",
)

# Recombine GSP arrays - see function doc for further explanation
gsp_datapipe = gsp_datapipe.zip_ocf(gsp_future_datapipe).map(concat_xr_time_utc)
gsp_datapipe = gsp_datapipe.normalize(normalize_fn=normalize_gsp)

Check warning on line 282 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L281-L282

Added lines #L281 - L282 were not covered by tests

finished_dataset_dict["gsp"] = gsp_datapipe

Check warning on line 284 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L284

Added line #L284 was not covered by tests

get_loc_datapipe.close()

Check warning on line 286 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L286

Added line #L286 was not covered by tests

if "nwp" in datapipes_dict:
finished_dataset_dict["nwp"] = nwp_datapipes_dict
if "sat" in datapipes_dict:
finished_dataset_dict["sat"] = sat_datapipe
if "pv" in datapipes_dict:
finished_dataset_dict["pv"] = pv_datapipe

Check warning on line 293 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L288-L293

Added lines #L288 - L293 were not covered by tests

return finished_dataset_dict

Check warning on line 295 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L295

Added line #L295 was not covered by tests


def pvnet_site_datapipe(

Check warning on line 298 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L298

Added line #L298 was not covered by tests
config_filename: str,
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
) -> IterDataPipe:
"""
Construct PVNet site pipeline for the input data config file.

Args:
config_filename: Path to config file.
start_time: Minimum time at which a sample can be selected.
end_time: Maximum time at which a sample can be selected.
"""
logger.info("Constructing pvnet site pipeline")

Check warning on line 311 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L311

Added line #L311 was not covered by tests

# Open datasets from the config and filter to useable location-time pairs
location_pipe, t0_datapipe = construct_loctime_pipelines(

Check warning on line 314 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L314

Added line #L314 was not covered by tests
config_filename,
start_time,
end_time,
)

# Shard after we have the loc-times. These are already shuffled so no need to shuffle again
location_pipe = location_pipe.sharding_filter()
t0_datapipe = t0_datapipe.sharding_filter()

Check warning on line 322 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L321-L322

Added lines #L321 - L322 were not covered by tests

# In this function we re-open the datasets to make a clean separation before/after sharding
# This function
datapipe_dict = construct_sliced_data_pipeline(

Check warning on line 326 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L326

Added line #L326 was not covered by tests
config_filename,
location_pipe,
t0_datapipe,
)

# Merge all the datapipes into one
return DictDatasetIterDataPipe(

Check warning on line 333 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L333

Added line #L333 was not covered by tests
{k: v for k, v in datapipe_dict.items() if k != "config"},
).map(combine_to_single_dataset)


def split_dataset_dict_dp(element):

Check warning on line 338 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L338

Added line #L338 was not covered by tests
"""
Wrap each of the data source inputs into a datapipe

Args:
element: Dictionary of xarray objects
"""

element = flatten_nwp_source_dict(element)
output_dict = {k: IterableWrapper([v]) for k, v in element.items() if k != "config"}
output_dict = nest_nwp_source_dict(output_dict)

Check warning on line 348 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L346-L348

Added lines #L346 - L348 were not covered by tests

return output_dict

Check warning on line 350 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L350

Added line #L350 was not covered by tests


def pvnet_site_netcdf_datapipe(

Check warning on line 353 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L353

Added line #L353 was not covered by tests
config_filename: str,
keys: List[str],
filenames: List[str],
) -> IterDataPipe:
"""
Load the saved Datapipes from pvnet site, and transform to numpy batch

Args:
config_filename: Path to config file.
keys: List of keys to extract from the single NetCDF files
filenames: List of NetCDF files to load

Returns:
Datapipe that transforms the NetCDF files to numpy batch
"""
logger.info("Constructing pvnet site file pipeline")
config_datapipe = OpenConfiguration(config_filename)
configuration: Configuration = next(iter(config_datapipe))

Check warning on line 371 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L369-L371

Added lines #L369 - L371 were not covered by tests
# Load files
datapipe_dict_dp: IterDataPipe = LoadDictDatasetIterDataPipe(

Check warning on line 373 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L373

Added line #L373 was not covered by tests
filenames=filenames,
keys=keys,
).map(split_dataset_dict_dp)
datapipe = datapipe_dict_dp.pvnet_site_convert_to_numpy_batch(configuration=configuration)

Check warning on line 377 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L377

Added line #L377 was not covered by tests

return datapipe

Check warning on line 379 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L379

Added line #L379 was not covered by tests


if __name__ == "__main__":

Check warning on line 382 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L382

Added line #L382 was not covered by tests
# Load the ECMWF and sensor data here
datapipe = pvnet_site_datapipe(

Check warning on line 384 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L384

Added line #L384 was not covered by tests
config_filename="/home/jacob/Development/ocf_datapipes/tests/config/india_test.yaml",
start_time=datetime(2023, 1, 1),
end_time=datetime(2023, 11, 2),
)
batch = next(iter(datapipe))
print(batch)
batch.to_netcdf("test.nc", engine="h5netcdf")

Check warning on line 391 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L389-L391

Added lines #L389 - L391 were not covered by tests
# Load the saved NetCDF files here
datapipe = pvnet_site_netcdf_datapipe(

Check warning on line 393 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L393

Added line #L393 was not covered by tests
config_filename="/home/jacob/Development/ocf_datapipes/tests/config/india_test.yaml",
keys=["nwp", "pv"],
filenames=["test.nc"],
)
batch = next(iter(datapipe))
print(batch)

Check warning on line 399 in ocf_datapipes/training/pvnet_site.py

View check run for this annotation

Codecov / codecov/patch

ocf_datapipes/training/pvnet_site.py#L398-L399

Added lines #L398 - L399 were not covered by tests
Loading