diff --git a/scripts/save_batches.py b/scripts/save_batches.py index 10bdb527..3f8dbaf8 100644 --- a/scripts/save_batches.py +++ b/scripts/save_batches.py @@ -30,6 +30,7 @@ import torch from ocf_datapipes.training.pvnet import pvnet_datapipe from ocf_datapipes.training.windnet import windnet_datapipe +from ocf_datapipes.training.pvnet_site import pvnet_site_datapipe from omegaconf import DictConfig, OmegaConf from sqlalchemy import exc as sa_exc from torch.utils.data import DataLoader @@ -63,6 +64,8 @@ def _get_datapipe(config_path, start_time, end_time, batch_size, renewable: str data_pipeline_fn = pvnet_datapipe elif renewable == "wind": data_pipeline_fn = windnet_datapipe + elif renewable == "pv_india": + data_pipeline_fn = pvnet_site_datapipe else: raise ValueError(f"Unknown renewable: {renewable}") data_pipeline = data_pipeline_fn(