Skip to content

Commit b2e5104

Browse files
committed
Test push
1 parent 3d8a9fe commit b2e5104

File tree

5 files changed

+13
-32
lines changed

5 files changed

+13
-32
lines changed

pvnet/data/base_datamodule.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
""" Data module for pytorch lightning """
22

33
from glob import glob
4-
4+
from typing import Type
55
from lightning.pytorch import LightningDataModule
66
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
7-
from ocf_data_sampler.sample.base import SampleBase
8-
from ocf_datapipes.batch import (
7+
8+
from ocf_data_sampler.sample.base import (
9+
SampleBase,
910
NumpyBatch,
1011
TensorBatch,
1112
batch_to_tensor,
@@ -26,8 +27,8 @@ class PremadeSamplesDataset(Dataset):
2627
sample_class: sample class type to use for save/load/to_numpy
2728
"""
2829

29-
def __init__(self, sample_dir: str, sample_class: SampleBase):
30-
"""Initialise PremadeSamplesDataset"""
30+
def __init__(self, sample_dir: str, sample_class: Type[SampleBase]):
31+
"""Initialise PremadeSamplesDataset"""
3132
self.sample_paths = glob(f"{sample_dir}/*")
3233
self.sample_class = sample_class
3334

@@ -99,7 +100,8 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
99100
raise NotImplementedError
100101

101102
def _get_premade_samples_dataset(self, subdir) -> Dataset:
102-
raise NotImplementedError
103+
split_dir = f"{self.sample_dir}/{subdir}"
104+
return PremadeSamplesDataset(split_dir, None)
103105

104106
def train_dataloader(self) -> DataLoader:
105107
"""Construct train dataloader"""

pvnet/data/site_datamodule.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
""" Data module for pytorch lightning """
22

33
from ocf_data_sampler.sample.site import SiteSample
4-
from ocf_data_sampler.torch_datasets.datasets.site import (
5-
SitesDataset,
6-
)
4+
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
75
from torch.utils.data import Dataset
86

97
from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset
108

119

1210
class SiteDataModule(BaseDataModule):
13-
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""
11+
"""Datamodule for training pvnet and using pvnet pipeline in `ocf-data-sampler`."""
1412

1513
def __init__(
1614
self,

pvnet/data/uk_regional_datamodule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
""" Data module for pytorch lightning """
22

33
from ocf_data_sampler.sample.uk_regional import UKRegionalSample
4-
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
4+
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
55
from torch.utils.data import Dataset
66

77
from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ dynamic = ["version", "readme"]
66
license={file="LICENCE"}
77

88
dependencies = [
9-
"ocf_data_sampler==0.0.54",
9+
"ocf_data_sampler==0.1.7",
1010
"ocf_datapipes>=3.3.34",
1111
"ocf_ml_metrics>=0.0.11",
1212
"numpy",

tests/data/test_datamodule.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ def test_iter():
2626
val_period=[None, None],
2727
)
2828

29-
batch = next(iter(dm.train_dataloader()))
30-
3129

3230
def test_iter_multiprocessing():
3331
dm = DataModule(
@@ -55,12 +53,12 @@ def test_iter_multiprocessing():
5553
def test_site_init_sample_dir():
5654
dm = SiteDataModule(
5755
configuration=None,
56+
sample_dir="tests/test_data/presaved_site_samples",
5857
batch_size=2,
5958
num_workers=0,
6059
prefetch_factor=None,
6160
train_period=[None, None],
6261
val_period=[None, None],
63-
sample_dir="tests/test_data/presaved_site_samples",
6462
)
6563

6664

@@ -75,20 +73,3 @@ def test_site_init_config():
7573
sample_dir=None,
7674
)
7775

78-
79-
def test_worker_configuration():
80-
dm = DataModule(
81-
sample_dir="tests/test_data/presaved_samples_uk_regional",
82-
batch_size=2,
83-
num_workers=4,
84-
prefetch_factor=2,
85-
)
86-
87-
# Iterate through dataloader - assert multi-processing functions fine
88-
batches_processed = 0
89-
for batch in dm.train_dataloader():
90-
batches_processed += 1
91-
if batches_processed >= 5:
92-
break
93-
94-
assert batches_processed > 0, "No batches were processed"

0 commit comments

Comments
 (0)