Skip to content

Commit

Permalink
base, site and uk_regional _datamodule related updates
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-e-h-p committed Jan 30, 2025
1 parent 8fab7ab commit 8bc9d05
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 49 deletions.
20 changes: 20 additions & 0 deletions pvnet/data/base_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
""" Data module for pytorch lightning """

from glob import glob
from lightning.pytorch import LightningDataModule
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
from ocf_datapipes.batch import (
Expand All @@ -14,6 +16,24 @@ def collate_fn(samples: list[NumpyBatch]) -> TensorBatch:
return batch_to_tensor(stack_np_samples_into_batch(samples))


class PremadeSamplesDataset(Dataset):
    """Dataset that loads pre-saved samples from a directory.

    Every entry matched by ``{sample_dir}/*`` is treated as one sample. A
    sample is read with ``sample_class.load(path)`` and handed back as the
    result of calling ``to_numpy()`` on the loaded object.

    Args:
        sample_dir: Path to the directory of pre-saved samples.
        sample_class: Class used to read a sample. It must expose a
            ``load(path)`` callable whose result has a ``to_numpy()`` method.
    """

    def __init__(self, sample_dir: str, sample_class):
        # glob() yields paths in arbitrary, filesystem-dependent order; sorting
        # makes a given index refer to the same file on every run and machine.
        self.sample_paths = sorted(glob(f"{sample_dir}/*"))
        self.sample_class = sample_class

    def __len__(self) -> int:
        """Return the number of sample paths found in the directory."""
        return len(self.sample_paths)

    def __getitem__(self, idx):
        """Load the sample stored at position ``idx`` and return ``to_numpy()`` of it."""
        sample = self.sample_class.load(self.sample_paths[idx])
        return sample.to_numpy()


class BaseDataModule(LightningDataModule):
"""Base Datamodule for training pvnet and using pvnet pipeline in ocf-data-sampler."""

Expand Down
31 changes: 5 additions & 26 deletions pvnet/data/site_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,14 @@
""" Data module for pytorch lightning """
from glob import glob

from ocf_data_sampler.sample.site import SiteSample
import xarray as xr
from ocf_data_sampler.torch_datasets.datasets.site import (
SitesDataset,
convert_netcdf_to_numpy_sample,
)
from ocf_data_sampler.sample.site import SiteSample
from torch.utils.data import Dataset

from pvnet.data.base_datamodule import BaseDataModule


class NetcdfPreMadeSamplesDataset(Dataset):
    """Dataset to load pre-made netcdf samples.

    Only files matching ``{sample_dir}/*.nc`` are picked up. Each one is read
    with ``SiteSample.load`` and returned as the result of ``to_numpy()``.
    """

    def __init__(
        self,
        sample_dir,
    ):
        """Collect the paths of the pre-made netcdf samples.

        Args:
            sample_dir: Path to the directory of pre-saved samples.
        """
        # glob() yields paths in arbitrary, filesystem-dependent order; sorting
        # makes a given index refer to the same file on every run and machine.
        self.sample_paths = sorted(glob(f"{sample_dir}/*.nc"))

    def __len__(self):
        """Return the number of ``.nc`` sample files found."""
        return len(self.sample_paths)

    def __getitem__(self, idx):
        """Load the sample stored at position ``idx`` and return ``to_numpy()`` of it."""
        sample = SiteSample.load(self.sample_paths[idx])
        return sample.to_numpy()
from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset


class SiteDataModule(BaseDataModule):
Expand Down Expand Up @@ -75,4 +54,4 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:

# Build the premade-samples dataset for one sub-directory of `self.sample_dir`.
#   subdir: sub-directory name, joined as f"{self.sample_dir}/{subdir}".
def _get_premade_samples_dataset(self, subdir) -> Dataset:
split_dir = f"{self.sample_dir}/{subdir}"
# NOTE(review): the two returns below look like the removed and the added line
# of this diff hunk (the +/- markers are not present here) - only one of them
# belongs in the real file; confirm against the repository.
return NetcdfPreMadeSamplesDataset(split_dir)
return PremadeSamplesDataset(split_dir, SiteSample)
26 changes: 3 additions & 23 deletions pvnet/data/uk_regional_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
""" Data module for pytorch lightning """
from glob import glob

from ocf_data_sampler.sample.uk_regional import UKRegionalSample
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from torch.utils.data import Dataset

from pvnet.data.base_datamodule import BaseDataModule


class NumpybatchPremadeSamplesDataset(Dataset):
    """Dataset to load NumpyBatch samples.

    Only files matching ``{sample_dir}/*.pt`` are picked up. Each one is read
    with ``UKRegionalSample.load`` and returned as the result of ``to_numpy()``.
    """

    def __init__(self, sample_dir):
        """Collect the paths of the pre-saved samples.

        Args:
            sample_dir: Path to the directory of pre-saved samples.
        """
        # glob() yields paths in arbitrary, filesystem-dependent order; sorting
        # makes a given index refer to the same file on every run and machine.
        self.sample_paths = sorted(glob(f"{sample_dir}/*.pt"))

    def __len__(self):
        """Return the number of ``.pt`` sample files found."""
        return len(self.sample_paths)

    def __getitem__(self, idx):
        """Load the sample stored at position ``idx`` and return ``to_numpy()`` of it."""
        # NOTE(review): an earlier comment called the result "a dict of tensors";
        # the actual type is whatever to_numpy() produces - confirm.
        sample = UKRegionalSample.load(self.sample_paths[idx])
        return sample.to_numpy()
from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset


class DataModule(BaseDataModule):
Expand Down Expand Up @@ -71,4 +50,5 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:

# Build the premade-samples dataset for one sub-directory of `self.sample_dir`.
#   subdir: sub-directory name, joined as f"{self.sample_dir}/{subdir}".
def _get_premade_samples_dataset(self, subdir) -> Dataset:
split_dir = f"{self.sample_dir}/{subdir}"
# NOTE(review): the two returns below look like the removed and the added line
# of this diff hunk (the +/- markers are not present here) - only one of them
# belongs in the real file; confirm against the repository.
return NumpybatchPremadeSamplesDataset(split_dir)
# NOTE(review): described as "a dict of tensors", but items come from
# `sample.to_numpy()` in the dataset's __getitem__ - confirm the actual type.
return PremadeSamplesDataset(split_dir, UKRegionalSample)

0 comments on commit 8bc9d05

Please sign in to comment.