diff --git a/pvnet/data/base_datamodule.py b/pvnet/data/base_datamodule.py index b130cb8a..64b49ae9 100644 --- a/pvnet/data/base_datamodule.py +++ b/pvnet/data/base_datamodule.py @@ -2,12 +2,12 @@ from glob import glob from typing import Type + from lightning.pytorch import LightningDataModule from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch - from ocf_data_sampler.sample.base import ( - SampleBase, NumpyBatch, + SampleBase, TensorBatch, batch_to_tensor, ) @@ -28,7 +28,7 @@ class PremadeSamplesDataset(Dataset): """ def __init__(self, sample_dir: str, sample_class: Type[SampleBase]): - """Initialise PremadeSamplesDataset""" + """Initialise PremadeSamplesDataset""" self.sample_paths = glob(f"{sample_dir}/*") self.sample_class = sample_class diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 8eb2c587..572cb5e2 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -72,4 +72,3 @@ def test_site_init_config(): val_period=[None, None], sample_dir=None, ) -