From 075ae315c93c98ff345561603e9e834c8a1218e4 Mon Sep 17 00:00:00 2001 From: Sukhil Patel <42407101+Sukh-P@users.noreply.github.com> Date: Tue, 21 Jan 2025 11:57:27 +0000 Subject: [PATCH] Add sample saving for Site Dataset + presaved SiteDataset/DataModule for loading presaved samples (#290) --- README.md | 12 +- pvnet/data/__init__.py | 2 + pvnet/data/base.py | 116 ------------------ pvnet/data/base_datamodule.py | 93 ++++++++++++++ pvnet/data/pv_site_datamodule.py | 67 ---------- pvnet/data/site_datamodule.py | 80 ++++++++++++ ...atamodule.py => uk_regional_datamodule.py} | 0 pvnet/data/wind_datamodule.py | 62 ---------- pyproject.toml | 2 +- scripts/save_samples.py | 22 ++-- tests/conftest.py | 7 +- tests/data/test_datamodule.py | 81 ++++-------- .../multimodal/site_encoders/test_encoders.py | 15 +-- .../data_configuration.yaml | 0 .../datamodule.yaml | 0 .../train/000000.nc | Bin .../train/000001.nc | Bin 17 files changed, 230 insertions(+), 329 deletions(-) delete mode 100644 pvnet/data/base.py create mode 100644 pvnet/data/base_datamodule.py delete mode 100644 pvnet/data/pv_site_datamodule.py create mode 100644 pvnet/data/site_datamodule.py rename pvnet/data/{datamodule.py => uk_regional_datamodule.py} (100%) delete mode 100644 pvnet/data/wind_datamodule.py rename tests/test_data/{sample_wind_batches => sample_site_batches}/data_configuration.yaml (100%) rename tests/test_data/{sample_wind_batches => sample_site_batches}/datamodule.yaml (100%) rename tests/test_data/{sample_wind_batches => sample_site_batches}/train/000000.nc (100%) rename tests/test_data/{sample_wind_batches => sample_site_batches}/train/000001.nc (100%) diff --git a/README.md b/README.md index 051e9a2e..69f54c64 100644 --- a/README.md +++ b/README.md @@ -145,20 +145,20 @@ This is also where you can update the train, val & test periods to cover the dat ### Running the batch creation script -Run the `save_batches.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example): +Run the `save_samples.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example): ```bash -python scripts/save_batches.py +python scripts/save_samples.py ``` PVNet uses [hydra](https://hydra.cc/) which enables us to pass variables via the command line that will override the configuration defined in the `./configs` directory, like this: ```bash -python scripts/save_batches.py datamodule=streamed_batches datamodule.batch_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5 +python scripts/save_samples.py datamodule=streamed_batches datamodule.sample_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5 ``` -`scripts/save_batches.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder. +`scripts/save_samples.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder. If downloading private data from a GCP bucket make sure to authenticate gcloud (the public satellite data does not need authentication): @@ -197,7 +197,7 @@ Make sure to update the following config files before training your model: 2. In `configs/model/local_multimodal.yaml`: - update the list of encoders to reflect the data sources you are using. If you are using different NWP sources, the encoders for these should follow the same structure with two important updates: - `in_channels`: number of variables your NWP source supplies - - `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `nwp_image_size_pixels_height` and/or `nwp_image_size_pixels_width` in `datamodule/example_configs.yaml`, unless transformations such as coarsening was applied (e. g. as for ECMWF data) + - `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `image_size_pixels_height` and/or `image_size_pixels_width` in `datamodule/configuration/site_example_configuration.yaml` for the NWP, unless transformations such as coarsening was applied (e. g. as for ECMWF data) 3. In `configs/local_trainer.yaml`: - set `accelerator: 0` if running on a system without a supported GPU @@ -216,7 +216,7 @@ defaults: - hydra: default.yaml ``` -Assuming you ran the `save_batches.py` script to generate some premade train and +Assuming you ran the `save_samples.py` script to generate some premade train and val data batches, you can now train PVNet by running: ``` diff --git a/pvnet/data/__init__.py b/pvnet/data/__init__.py index 5d763df0..98c955e1 100644 --- a/pvnet/data/__init__.py +++ b/pvnet/data/__init__.py @@ -1,2 +1,4 @@ """Data parts""" +from .site_datamodule import SiteDataModule +from .uk_regional_datamodule import DataModule from .utils import BatchSplitter diff --git a/pvnet/data/base.py b/pvnet/data/base.py deleted file mode 100644 index b53c4ee8..00000000 --- a/pvnet/data/base.py +++ /dev/null @@ -1,116 +0,0 @@ -""" Data module for pytorch lightning """ -from datetime import datetime - -from lightning.pytorch import LightningDataModule -from torch.utils.data import DataLoader - - -class BaseDataModule(LightningDataModule): - """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`.""" - - def __init__( - self, - configuration=None, - batch_size=16, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir=None, - shuffle_factor=100, - nwp_channels=None, - ): - """Datamodule for training pvnet architecture. - - Can also be used with pre-made batches if `batch_dir` is set. - - - Args: - configuration: Path to datapipe configuration file. - batch_size: Batch size. - num_workers: Number of workers to use in multiprocess batch loading. - prefetch_factor: Number of data will be prefetched at the end of each worker process. - train_period: Date range filter for train dataloader. - val_period: Date range filter for val dataloader. - test_period: Date range filter for test dataloader. - batch_dir: Path to the directory of pre-saved batches. Cannot be used together with - `configuration` or 'train/val/test_period'. - shuffle_factor: Number of presaved batches to be split and reshuffled to create returned - batches. A larger factor means on each epoch the batches will be more diverse but at - the cost of using more RAM. - nwp_channels: Number of NWP channels to use. If None, the all channels are used - """ - super().__init__() - self.configuration = configuration - self.batch_size = batch_size - self.batch_dir = batch_dir - self.shuffle_factor = shuffle_factor - self.nwp_channels = nwp_channels - - if not ((batch_dir is not None) ^ (configuration is not None)): - raise ValueError("Exactly one of `batch_dir` or `configuration` must be set.") - - if (nwp_channels is not None) and (batch_dir is None): - raise ValueError( - "In order for 'nwp_channels' to work, we need batch_dir. " - "Otherwise the nwp channels is one in the configuration" - ) - - if batch_dir is not None: - if any([period != [None, None] for period in [train_period, val_period, test_period]]): - raise ValueError("Cannot set `(train/val/test)_period` with presaved batches") - - self.train_period = [ - None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in train_period - ] - self.val_period = [ - None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in val_period - ] - self.test_period = [ - None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period - ] - - self._common_dataloader_kwargs = dict( - batch_size=None, # batched in datapipe step - sampler=None, - batch_sampler=None, - num_workers=num_workers, - collate_fn=None, - pin_memory=False, - drop_last=False, - timeout=0, - worker_init_fn=None, - prefetch_factor=prefetch_factor, - persistent_workers=False, - ) - - def _get_datapipe(self, start_time, end_time): - raise NotImplementedError - - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - raise NotImplementedError - - def train_dataloader(self): - """Construct train dataloader""" - if self.batch_dir is not None: - datapipe = self._get_premade_batches_datapipe("train", shuffle=True) - else: - datapipe = self._get_datapipe(*self.train_period) - return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs) - - def val_dataloader(self): - """Construct val dataloader""" - if self.batch_dir is not None: - datapipe = self._get_premade_batches_datapipe("val") - else: - datapipe = self._get_datapipe(*self.val_period) - return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs) - - def test_dataloader(self): - """Construct test dataloader""" - if self.batch_dir is not None: - datapipe = self._get_premade_batches_datapipe("test") - else: - datapipe = self._get_datapipe(*self.test_period) - return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs) diff --git a/pvnet/data/base_datamodule.py b/pvnet/data/base_datamodule.py new file mode 100644 index 00000000..9e00d843 --- /dev/null +++ b/pvnet/data/base_datamodule.py @@ -0,0 +1,93 @@ +""" Data module for pytorch lightning """ +from lightning.pytorch import LightningDataModule +from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch +from ocf_datapipes.batch import ( + NumpyBatch, + TensorBatch, + batch_to_tensor, +) +from torch.utils.data import DataLoader, Dataset + + +def collate_fn(samples: list[NumpyBatch]) -> TensorBatch: + """Convert a list of NumpySample samples to a tensor batch""" + return batch_to_tensor(stack_np_samples_into_batch(samples)) + + +class BaseDataModule(LightningDataModule): + """Base Datamodule for training pvnet and using pvnet pipeline in ocf-data-sampler.""" + + def __init__( + self, + configuration: str | None = None, + sample_dir: str | None = None, + batch_size: int = 16, + num_workers: int = 0, + prefetch_factor: int | None = None, + train_period: list[str | None] = [None, None], + val_period: list[str | None] = [None, None], + ): + """Base Datamodule for training pvnet architecture. + + Can also be used with pre-made batches if `sample_dir` is set. + + Args: + configuration: Path to ocf-data-sampler configuration file. + sample_dir: Path to the directory of pre-saved samples. Cannot be used together with + `configuration` or '[train/val]_period'. + batch_size: Batch size. + num_workers: Number of workers to use in multiprocess batch loading. + prefetch_factor: Number of data will be prefetched at the end of each worker process. + train_period: Date range filter for train dataloader. + val_period: Date range filter for val dataloader. + + """ + super().__init__() + + if not ((sample_dir is not None) ^ (configuration is not None)): + raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.") + + if sample_dir is not None: + if any([period != [None, None] for period in [train_period, val_period]]): + raise ValueError("Cannot set `(train/val)_period` with presaved samples") + + self.configuration = configuration + self.sample_dir = sample_dir + self.train_period = train_period + self.val_period = val_period + + self._common_dataloader_kwargs = dict( + batch_size=batch_size, + sampler=None, + batch_sampler=None, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=False, + drop_last=False, + timeout=0, + worker_init_fn=None, + prefetch_factor=prefetch_factor, + persistent_workers=False, + ) + + def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: + raise NotImplementedError + + def _get_premade_samples_dataset(self, subdir) -> Dataset: + raise NotImplementedError + + def train_dataloader(self) -> DataLoader: + """Construct train dataloader""" + if self.sample_dir is not None: + dataset = self._get_premade_samples_dataset("train") + else: + dataset = self._get_streamed_samples_dataset(*self.train_period) + return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs) + + def val_dataloader(self) -> DataLoader: + """Construct val dataloader""" + if self.sample_dir is not None: + dataset = self._get_premade_samples_dataset("val") + else: + dataset = self._get_streamed_samples_dataset(*self.val_period) + return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs) diff --git a/pvnet/data/pv_site_datamodule.py b/pvnet/data/pv_site_datamodule.py deleted file mode 100644 index 4c45eaec..00000000 --- a/pvnet/data/pv_site_datamodule.py +++ /dev/null @@ -1,67 +0,0 @@ -""" Data module for pytorch lightning """ -import glob - -from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.pvnet_site import ( - pvnet_site_datapipe, - pvnet_site_netcdf_datapipe, - split_dataset_dict_dp, - uncombine_from_single_dataset, -) - -from pvnet.data.base import BaseDataModule - - -class PVSiteDataModule(BaseDataModule): - """Datamodule for training pvnet site and using pvnet site pipeline in `ocf_datapipes`.""" - - def _get_datapipe(self, start_time, end_time): - data_pipeline = pvnet_site_datapipe( - self.configuration, - start_time=start_time, - end_time=end_time, - ) - data_pipeline = data_pipeline.map(uncombine_from_single_dataset).map(split_dataset_dict_dp) - data_pipeline = data_pipeline.pvnet_site_convert_to_numpy_batch() - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - return data_pipeline - - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")) - data_pipeline = pvnet_site_netcdf_datapipe( - keys=["pv", "nwp"], # add other keys e.g. sat if used as input in site model - filenames=filenames, - ) - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - if shuffle: - data_pipeline = ( - data_pipeline.shuffle(buffer_size=100) - .sharding_filter() - # Split the batches and reshuffle them to be combined into new batches - .split_batches(splitting_key=BatchKey.pv) - .shuffle(buffer_size=self.shuffle_factor * self.batch_size) - ) - else: - data_pipeline = ( - data_pipeline.sharding_filter() - # Split the batches so we can use any batch-size - .split_batches(splitting_key=BatchKey.pv) - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - .set_length(int(len(filenames) / self.batch_size)) - ) - - return data_pipeline diff --git a/pvnet/data/site_datamodule.py b/pvnet/data/site_datamodule.py new file mode 100644 index 00000000..b55803ec --- /dev/null +++ b/pvnet/data/site_datamodule.py @@ -0,0 +1,80 @@ +""" Data module for pytorch lightning """ +from glob import glob + +import xarray as xr +from ocf_data_sampler.torch_datasets.site import SitesDataset, convert_netcdf_to_numpy_sample +from torch.utils.data import Dataset + +from pvnet.data.base_datamodule import BaseDataModule + + +class NetcdfPreMadeSamplesDataset(Dataset): + """Dataset to load pre-made netcdf samples""" + + def __init__( + self, + sample_dir, + ): + """Dataset to load pre-made netcdf samples + + Args: + sample_dir: Path to the directory of pre-saved samples. + """ + self.sample_paths = glob(f"{sample_dir}/*.nc") + + def __len__(self): + return len(self.sample_paths) + + def __getitem__(self, idx): + # open the sample + ds = xr.open_dataset(self.sample_paths[idx]) + + # convert to numpy + sample = convert_netcdf_to_numpy_sample(ds) + return sample + + +class SiteDataModule(BaseDataModule): + """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`.""" + + def __init__( + self, + configuration: str | None = None, + sample_dir: str | None = None, + batch_size: int = 16, + num_workers: int = 0, + prefetch_factor: int | None = None, + train_period: list[str | None] = [None, None], + val_period: list[str | None] = [None, None], + ): + """Datamodule for training pvnet architecture. + + Can also be used with pre-made batches if `sample_dir` is set. + + Args: + configuration: Path to datapipe configuration file. + sample_dir: Path to the directory of pre-saved samples. Cannot be used together with + `configuration` or '[train/val]_period'. + batch_size: Batch size. + num_workers: Number of workers to use in multiprocess batch loading. + prefetch_factor: Number of data will be prefetched at the end of each worker process. + train_period: Date range filter for train dataloader. + val_period: Date range filter for val dataloader. + + """ + super().__init__( + configuration=configuration, + sample_dir=sample_dir, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + train_period=train_period, + val_period=val_period, + ) + + def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: + return SitesDataset(self.configuration, start_time=start_time, end_time=end_time) + + def _get_premade_samples_dataset(self, subdir) -> Dataset: + split_dir = f"{self.sample_dir}/{subdir}" + return NetcdfPreMadeSamplesDataset(split_dir) diff --git a/pvnet/data/datamodule.py b/pvnet/data/uk_regional_datamodule.py similarity index 100% rename from pvnet/data/datamodule.py rename to pvnet/data/uk_regional_datamodule.py diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py deleted file mode 100644 index 0c11d31d..00000000 --- a/pvnet/data/wind_datamodule.py +++ /dev/null @@ -1,62 +0,0 @@ -""" Data module for pytorch lightning """ -import glob - -from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch -from ocf_datapipes.training.windnet import windnet_netcdf_datapipe - -from pvnet.data.base import BaseDataModule - - -class WindDataModule(BaseDataModule): - """Datamodule for training windnet and using windnet pipeline in `ocf_datapipes`.""" - - def _get_datapipe(self, start_time, end_time): - # TODO is this is not right, need to load full windnet pipeline - data_pipeline = windnet_netcdf_datapipe( - self.configuration, - keys=["wind", "nwp", "sensor"], - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - return data_pipeline - - def _get_premade_batches_datapipe(self, subdir, shuffle=False): - filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")) - data_pipeline = windnet_netcdf_datapipe( - keys=["wind", "nwp", "sensor"], - filenames=filenames, - nwp_channels=self.nwp_channels, - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - ) - if shuffle: - data_pipeline = ( - data_pipeline.shuffle(buffer_size=100) - .sharding_filter() - # Split the batches and reshuffle them to be combined into new batches - .split_batches(splitting_key=BatchKey.wind) - .shuffle(buffer_size=self.shuffle_factor * self.batch_size) - ) - else: - data_pipeline = ( - data_pipeline.sharding_filter() - # Split the batches so we can use any batch-size - .split_batches(splitting_key=BatchKey.wind) - ) - - data_pipeline = ( - data_pipeline.batch(self.batch_size) - .map(stack_np_examples_into_batch) - .map(batch_to_tensor) - .set_length(int(len(filenames) / self.batch_size)) - ) - - return data_pipeline diff --git a/pyproject.toml b/pyproject.toml index b931d605..1fd0e2c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ dynamic = ["version", "readme"] license={file="LICENCE"} dependencies = [ - "ocf_data_sampler==0.0.32", + "ocf_data_sampler==0.0.43", "ocf_datapipes>=3.3.34", "ocf_ml_metrics>=0.0.11", "numpy", diff --git a/scripts/save_samples.py b/scripts/save_samples.py index d38a45f9..edc6f035 100644 --- a/scripts/save_samples.py +++ b/scripts/save_samples.py @@ -44,7 +44,7 @@ import dask import hydra import torch -from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset +from ocf_data_sampler.torch_datasets import PVNetUKRegionalDataset, SitesDataset from omegaconf import DictConfig, OmegaConf from sqlalchemy import exc as sa_exc from torch.utils.data import DataLoader, Dataset @@ -69,27 +69,29 @@ class SaveFuncFactory: """Factory for creating a function to save a sample to disk.""" - def __init__(self, save_dir: str, renewable: str = "pv"): + def __init__(self, save_dir: str, renewable: str = "pv_uk"): """Factory for creating a function to save a sample to disk.""" self.save_dir = save_dir self.renewable = renewable def __call__(self, sample, sample_num: int): """Save a sample to disk""" - if self.renewable == "pv": + if self.renewable == "pv_uk": torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt") - elif self.renewable in ["wind", "pv_india", "pv_site"]: - raise NotImplementedError + elif self.renewable == "site": + sample.to_netcdf(f"{self.save_dir}/{sample_num:08}.nc", mode="w", engine="h5netcdf") else: raise ValueError(f"Unknown renewable: {self.renewable}") -def get_dataset(config_path: str, start_time: str, end_time: str, renewable: str = "pv") -> Dataset: +def get_dataset( + config_path: str, start_time: str, end_time: str, renewable: str = "pv_uk" +) -> Dataset: """Get the dataset for the given renewable type.""" - if renewable == "pv": + if renewable == "pv_uk": dataset_cls = PVNetUKRegionalDataset - elif renewable in ["wind", "pv_india", "pv_site"]: - raise NotImplementedError + elif renewable == "site": + dataset_cls = SitesDataset else: raise ValueError(f"Unknown renewable: {renewable}") @@ -101,7 +103,7 @@ def save_samples_with_dataloader( save_dir: str, num_samples: int, dataloader_kwargs: dict, - renewable: str = "pv", + renewable: str = "pv_uk", ) -> None: """Save samples from a dataset using a dataloader.""" save_func = SaveFuncFactory(save_dir, renewable=renewable) diff --git a/tests/conftest.py b/tests/conftest.py index ba657af5..d5d9ae7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,8 +13,7 @@ from datetime import timedelta import pvnet -from pvnet.data.datamodule import DataModule -from pvnet.data.wind_datamodule import WindDataModule +from pvnet.data import DataModule, SiteDataModule import pvnet.models.multimodal.encoders.encoders3d import pvnet.models.multimodal.linear_networks.networks @@ -161,7 +160,7 @@ def sample_pv_batch(): @pytest.fixture() def sample_wind_batch(): - dm = WindDataModule( + dm = SiteDataModule( configuration=None, batch_size=2, num_workers=0, @@ -169,7 +168,7 @@ def sample_wind_batch(): train_period=[None, None], val_period=[None, None], test_period=[None, None], - batch_dir="tests/test_data/sample_wind_batches", + batch_dir="tests/test_data/sample_site_batches", ) batch = next(iter(dm.train_dataloader())) return batch diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 00b1705d..e9b8ce42 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -1,9 +1,5 @@ -import pytest -from pvnet.data.datamodule import DataModule -from pvnet.data.wind_datamodule import WindDataModule -from pvnet.data.pv_site_datamodule import PVSiteDataModule +from pvnet.data import DataModule, SiteDataModule import os -from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey def test_init(): @@ -18,57 +14,6 @@ def test_init(): ) -@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") -def test_wind_init(): - dm = WindDataModule( - configuration=None, - batch_size=2, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir="tests/data/sample_batches", - ) - - -@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") -def test_wind_init_with_nwp_filter(): - dm = WindDataModule( - configuration=None, - batch_size=2, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir="tests/test_data/sample_wind_batches", - nwp_channels={"ecmwf": ["t2m", "v200"]}, - ) - dataloader = iter(dm.train_dataloader()) - - batch = next(dataloader) - batch_channels = batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp_channel_names] - print(batch_channels) - for v in ["t2m", "v200"]: - assert v in batch_channels - assert batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp].shape[2] == 2 - - -@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet") -def test_pv_site_init(): - dm = PVSiteDataModule( - configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml", - batch_size=2, - num_workers=0, - prefetch_factor=None, - train_period=[None, None], - val_period=[None, None], - test_period=[None, None], - batch_dir=None, - ) - - def test_iter(): dm = DataModule( configuration=None, @@ -104,3 +49,27 @@ def test_iter_multiprocessing(): # Make sure we've served 2 batches assert served_batches == 2 + + +def test_site_init_sample_dir(): + dm = SiteDataModule( + configuration=None, + batch_size=2, + num_workers=0, + prefetch_factor=None, + train_period=[None, None], + val_period=[None, None], + sample_dir="tests/test_data/sample_site_batches", + ) + + +def test_site_init_config(): + dm = SiteDataModule( + configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_site_batches/data_configuration.yaml", + batch_size=2, + num_workers=0, + prefetch_factor=None, + train_period=[None, None], + val_period=[None, None], + sample_dir=None, + ) diff --git a/tests/models/multimodal/site_encoders/test_encoders.py b/tests/models/multimodal/site_encoders/test_encoders.py index 41969b22..48938bc3 100644 --- a/tests/models/multimodal/site_encoders/test_encoders.py +++ b/tests/models/multimodal/site_encoders/test_encoders.py @@ -42,13 +42,14 @@ def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwar ) -def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): - _test_model_forward( - sample_wind_batch, - SingleAttentionNetwork, - site_encoder_sensor_model_kwargs, - batch_size=2, - ) +# TODO once we have updated the sample batches for sites include this test +# def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs): +# _test_model_forward( +# sample_wind_batch, +# SingleAttentionNetwork, +# site_encoder_sensor_model_kwargs, +# batch_size=2, +# ) # Test model backward on all models diff --git a/tests/test_data/sample_wind_batches/data_configuration.yaml b/tests/test_data/sample_site_batches/data_configuration.yaml similarity index 100% rename from tests/test_data/sample_wind_batches/data_configuration.yaml rename to tests/test_data/sample_site_batches/data_configuration.yaml diff --git a/tests/test_data/sample_wind_batches/datamodule.yaml b/tests/test_data/sample_site_batches/datamodule.yaml similarity index 100% rename from tests/test_data/sample_wind_batches/datamodule.yaml rename to tests/test_data/sample_site_batches/datamodule.yaml diff --git a/tests/test_data/sample_wind_batches/train/000000.nc b/tests/test_data/sample_site_batches/train/000000.nc similarity index 100% rename from tests/test_data/sample_wind_batches/train/000000.nc rename to tests/test_data/sample_site_batches/train/000000.nc diff --git a/tests/test_data/sample_wind_batches/train/000001.nc b/tests/test_data/sample_site_batches/train/000001.nc similarity index 100% rename from tests/test_data/sample_wind_batches/train/000001.nc rename to tests/test_data/sample_site_batches/train/000001.nc