Skip to content

Commit 65cc88a

Browse files
authored
Data sampler updates (#316)
Data sampler migration-related updates
1 parent 13a24e0 commit 65cc88a

File tree

7 files changed

+7 additions, -49 deletions (lines changed)

7 files changed

+7 additions, -49 deletions (lines changed)

pvnet/data/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
11
"""Data parts"""
22
from .site_datamodule import SiteDataModule
33
from .uk_regional_datamodule import DataModule
4-
from .utils import BatchSplitter

pvnet/data/base_datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,9 +4,9 @@
44

55
from lightning.pytorch import LightningDataModule
66
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
7-
from ocf_data_sampler.sample.base import SampleBase
8-
from ocf_datapipes.batch import (
7+
from ocf_data_sampler.sample.base import (
98
NumpyBatch,
9+
SampleBase,
1010
TensorBatch,
1111
batch_to_tensor,
1212
)

pvnet/data/site_datamodule.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,14 @@
11
""" Data module for pytorch lightning """
22

33
from ocf_data_sampler.sample.site import SiteSample
4-
from ocf_data_sampler.torch_datasets.datasets.site import (
5-
SitesDataset,
6-
)
4+
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset
75
from torch.utils.data import Dataset
86

97
from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset
108

119

1210
class SiteDataModule(BaseDataModule):
13-
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""
11+
"""Datamodule for training pvnet and using pvnet pipeline in `ocf-data-sampler`."""
1412

1513
def __init__(
1614
self,

pvnet/data/utils.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

pvnet/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
import xarray as xr
1414
from lightning.pytorch.loggers import Logger
1515
from lightning.pytorch.utilities import rank_zero_only
16-
from ocf_datapipes.utils import Location
16+
from ocf_data_sampler.select.location import Location
1717
from omegaconf import DictConfig, OmegaConf
1818

1919

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ dynamic = ["version", "readme"]
66
license={file="LICENCE"}
77

88
dependencies = [
9-
"ocf_data_sampler==0.1.2",
9+
"ocf_data_sampler==0.1.7",
1010
"ocf_datapipes>=3.3.34",
1111
"ocf_ml_metrics>=0.0.11",
1212
"numpy",

tests/data/test_datamodule.py

Lines changed: 1 addition & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,6 @@ def test_iter():
2626
val_period=[None, None],
2727
)
2828

29-
batch = next(iter(dm.train_dataloader()))
30-
3129

3230
def test_iter_multiprocessing():
3331
dm = DataModule(
@@ -55,12 +53,12 @@ def test_iter_multiprocessing():
5553
def test_site_init_sample_dir():
5654
dm = SiteDataModule(
5755
configuration=None,
56+
sample_dir="tests/test_data/presaved_site_samples",
5857
batch_size=2,
5958
num_workers=0,
6059
prefetch_factor=None,
6160
train_period=[None, None],
6261
val_period=[None, None],
63-
sample_dir="tests/test_data/presaved_site_samples",
6462
)
6563

6664

@@ -74,21 +72,3 @@ def test_site_init_config():
7472
val_period=[None, None],
7573
sample_dir=None,
7674
)
77-
78-
79-
def test_worker_configuration():
80-
dm = DataModule(
81-
sample_dir="tests/test_data/presaved_samples_uk_regional",
82-
batch_size=2,
83-
num_workers=4,
84-
prefetch_factor=2,
85-
)
86-
87-
# Iterate through dataloader - assert multi-processing functions fine
88-
batches_processed = 0
89-
for batch in dm.train_dataloader():
90-
batches_processed += 1
91-
if batches_processed >= 5:
92-
break
93-
94-
assert batches_processed > 0, "No batches were processed"

0 commit comments

Comments
 (0)