We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5b7711b commit 79dca28Copy full SHA for 79dca28
pvnet/data/base_datamodule.py
@@ -4,6 +4,7 @@
4
5
from lightning.pytorch import LightningDataModule
6
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
7
+from ocf_data_sampler.sample.base import SampleBase
8
from ocf_datapipes.batch import (
9
NumpyBatch,
10
TensorBatch,
@@ -22,9 +23,10 @@ class PremadeSamplesDataset(Dataset):
22
23
24
Args:
25
sample_dir: Path to the directory of pre-saved samples.
26
+ sample_class: sample class type to use for save/load/to_numpy
27
"""
28
- def __init__(self, sample_dir: str, sample_class):
29
+ def __init__(self, sample_dir: str, sample_class: SampleBase):
30
"""Initialise PremadeSamplesDataset"""
31
self.sample_paths = glob(f"{sample_dir}/*")
32
self.sample_class = sample_class
0 commit comments