Skip to content

Commit

Permalink
Test push
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-e-h-p committed Feb 13, 2025
1 parent 3d8a9fe commit b2e5104
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 32 deletions.
14 changes: 8 additions & 6 deletions pvnet/data/base_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
""" Data module for pytorch lightning """

from glob import glob

from typing import Type
from lightning.pytorch import LightningDataModule
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
from ocf_data_sampler.sample.base import SampleBase
from ocf_datapipes.batch import (

from ocf_data_sampler.sample.base import (
SampleBase,
NumpyBatch,
TensorBatch,
batch_to_tensor,
Expand All @@ -26,8 +27,8 @@ class PremadeSamplesDataset(Dataset):
sample_class: sample class type to use for save/load/to_numpy
"""

def __init__(self, sample_dir: str, sample_class: SampleBase):
"""Initialise PremadeSamplesDataset"""
def __init__(self, sample_dir: str, sample_class: Type[SampleBase]):
"""Initialise PremadeSamplesDataset"""
self.sample_paths = glob(f"{sample_dir}/*")
self.sample_class = sample_class

Expand Down Expand Up @@ -99,7 +100,8 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
raise NotImplementedError

def _get_premade_samples_dataset(self, subdir) -> Dataset:
raise NotImplementedError
split_dir = f"{self.sample_dir}/{subdir}"
return PremadeSamplesDataset(split_dir, None)

def train_dataloader(self) -> DataLoader:
"""Construct train dataloader"""
Expand Down
6 changes: 2 additions & 4 deletions pvnet/data/site_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
""" Data module for pytorch lightning """

from ocf_data_sampler.sample.site import SiteSample
from ocf_data_sampler.torch_datasets.datasets.site import (
SitesDataset,
)
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
from torch.utils.data import Dataset

from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset


class SiteDataModule(BaseDataModule):
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""
"""Datamodule for training pvnet and using pvnet pipeline in `ocf-data-sampler`."""

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion pvnet/data/uk_regional_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
""" Data module for pytorch lightning """

from ocf_data_sampler.sample.uk_regional import UKRegionalSample
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
from torch.utils.data import Dataset

from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dynamic = ["version", "readme"]
license={file="LICENCE"}

dependencies = [
"ocf_data_sampler==0.0.54",
"ocf_data_sampler==0.1.7",
"ocf_datapipes>=3.3.34",
"ocf_ml_metrics>=0.0.11",
"numpy",
Expand Down
21 changes: 1 addition & 20 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def test_iter():
val_period=[None, None],
)

batch = next(iter(dm.train_dataloader()))


def test_iter_multiprocessing():
dm = DataModule(
Expand Down Expand Up @@ -55,12 +53,12 @@ def test_iter_multiprocessing():
def test_site_init_sample_dir():
dm = SiteDataModule(
configuration=None,
sample_dir="tests/test_data/presaved_site_samples",
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
sample_dir="tests/test_data/presaved_site_samples",
)


Expand All @@ -75,20 +73,3 @@ def test_site_init_config():
sample_dir=None,
)


def test_worker_configuration():
dm = DataModule(
sample_dir="tests/test_data/presaved_samples_uk_regional",
batch_size=2,
num_workers=4,
prefetch_factor=2,
)

# Iterate through dataloader - assert multi-processing functions fine
batches_processed = 0
for batch in dm.train_dataloader():
batches_processed += 1
if batches_processed >= 5:
break

assert batches_processed > 0, "No batches were processed"

0 comments on commit b2e5104

Please sign in to comment.