diff --git a/pvnet/data/base_datamodule.py b/pvnet/data/base_datamodule.py index 9e00d843..68d69713 100644 --- a/pvnet/data/base_datamodule.py +++ b/pvnet/data/base_datamodule.py @@ -1,4 +1,6 @@ """ Data module for pytorch lightning """ + +from glob import glob from lightning.pytorch import LightningDataModule from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch from ocf_datapipes.batch import ( @@ -14,6 +16,24 @@ def collate_fn(samples: list[NumpyBatch]) -> TensorBatch: return batch_to_tensor(stack_np_samples_into_batch(samples)) +class PremadeSamplesDataset(Dataset): + """Dataset to load samples from + + Args: + sample_dir: Path to the directory of pre-saved samples. + """ + def __init__(self, sample_dir: str, sample_class): + self.sample_paths = glob(f"{sample_dir}/*") + self.sample_class = sample_class + + def __len__(self): + return len(self.sample_paths) + + def __getitem__(self, idx): + sample = self.sample_class.load(self.sample_paths[idx]) + return sample.to_numpy() + + class BaseDataModule(LightningDataModule): """Base Datamodule for training pvnet and using pvnet pipeline in ocf-data-sampler.""" diff --git a/pvnet/data/site_datamodule.py b/pvnet/data/site_datamodule.py index 563afb68..bc60939e 100644 --- a/pvnet/data/site_datamodule.py +++ b/pvnet/data/site_datamodule.py @@ -1,35 +1,14 @@ """ Data module for pytorch lightning """ -from glob import glob -from ocf_data_sampler.sample.site import SiteSample +import xarray as xr from ocf_data_sampler.torch_datasets.datasets.site import ( SitesDataset, + convert_netcdf_to_numpy_sample, ) +from ocf_data_sampler.sample.site import SiteSample from torch.utils.data import Dataset -from pvnet.data.base_datamodule import BaseDataModule - - -class NetcdfPreMadeSamplesDataset(Dataset): - """Dataset to load pre-made netcdf samples""" - - def __init__( - self, - sample_dir, - ): - """Dataset to load pre-made netcdf samples - - Args: - sample_dir: Path to the directory of pre-saved samples. - """ - self.sample_paths = glob(f"{sample_dir}/*.nc") - - def __len__(self): - return len(self.sample_paths) - - def __getitem__(self, idx): - sample = SiteSample.load(self.sample_paths[idx]) - return sample.to_numpy() +from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset class SiteDataModule(BaseDataModule): @@ -75,4 +54,4 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: def _get_premade_samples_dataset(self, subdir) -> Dataset: split_dir = f"{self.sample_dir}/{subdir}" - return NetcdfPreMadeSamplesDataset(split_dir) + return PremadeSamplesDataset(split_dir, SiteSample) diff --git a/pvnet/data/uk_regional_datamodule.py b/pvnet/data/uk_regional_datamodule.py index 463bed19..3de9066d 100644 --- a/pvnet/data/uk_regional_datamodule.py +++ b/pvnet/data/uk_regional_datamodule.py @@ -1,31 +1,10 @@ """ Data module for pytorch lightning """ -from glob import glob from ocf_data_sampler.sample.uk_regional import UKRegionalSample from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset from torch.utils.data import Dataset -from pvnet.data.base_datamodule import BaseDataModule - - -class NumpybatchPremadeSamplesDataset(Dataset): - """Dataset to load NumpyBatch samples""" - - def __init__(self, sample_dir): - """Dataset to load NumpyBatch samples - - Args: - sample_dir: Path to the directory of pre-saved samples. - """ - self.sample_paths = glob(f"{sample_dir}/*.pt") - - def __len__(self): - return len(self.sample_paths) - - def __getitem__(self, idx): - # Returns a dict of tensors - sample = UKRegionalSample.load(self.sample_paths[idx]) - return sample.to_numpy() +from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset class DataModule(BaseDataModule): @@ -71,4 +50,5 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset: def _get_premade_samples_dataset(self, subdir) -> Dataset: split_dir = f"{self.sample_dir}/{subdir}" - return NumpybatchPremadeSamplesDataset(split_dir) + # Returns a dict of tensors + return PremadeSamplesDataset(split_dir, UKRegionalSample)