Skip to content

Commit 79dca28

Browse files
authored
Update PremadeSampleDataset signature (#312)
1 parent 5b7711b commit 79dca28

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pvnet/data/base_datamodule.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from lightning.pytorch import LightningDataModule
66
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
7+
from ocf_data_sampler.sample.base import SampleBase
78
from ocf_datapipes.batch import (
89
NumpyBatch,
910
TensorBatch,
@@ -22,9 +23,10 @@ class PremadeSamplesDataset(Dataset):
2223
2324
Args:
2425
sample_dir: Path to the directory of pre-saved samples.
26+
sample_class: sample class type to use for save/load/to_numpy
2527
"""
2628

27-
def __init__(self, sample_dir: str, sample_class):
29+
def __init__(self, sample_dir: str, sample_class: SampleBase):
2830
"""Initialise PremadeSamplesDataset"""
2931
self.sample_paths = glob(f"{sample_dir}/*")
3032
self.sample_class = sample_class

0 commit comments

Comments
 (0)