11""" Data module for pytorch lightning """
22
33from glob import glob
4-
4+ from typing import Type
55from lightning .pytorch import LightningDataModule
66from ocf_data_sampler .numpy_sample .collate import stack_np_samples_into_batch
7- from ocf_data_sampler .sample .base import SampleBase
8- from ocf_datapipes .batch import (
7+
8+ from ocf_data_sampler .sample .base import (
9+ SampleBase ,
910 NumpyBatch ,
1011 TensorBatch ,
1112 batch_to_tensor ,
@@ -26,8 +27,8 @@ class PremadeSamplesDataset(Dataset):
2627 sample_class: sample class type to use for save/load/to_numpy
2728 """
2829
29- def __init__ (self , sample_dir : str , sample_class : SampleBase ):
30- """Initialise PremadeSamplesDataset"""
30+ def __init__ (self , sample_dir : str , sample_class : Type [ SampleBase ] ):
31+ """Initialise PremadeSamplesDataset"""
3132 self .sample_paths = glob (f"{ sample_dir } /*" )
3233 self .sample_class = sample_class
3334
@@ -99,7 +100,8 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
99100 raise NotImplementedError
100101
101102 def _get_premade_samples_dataset (self , subdir ) -> Dataset :
102- raise NotImplementedError
103+ split_dir = f"{ self .sample_dir } /{ subdir } "
104+ return PremadeSamplesDataset (split_dir , None )
103105
104106 def train_dataloader (self ) -> DataLoader :
105107 """Construct train dataloader"""
0 commit comments