Skip to content

Commit db74cd4

Browse files
committed
Fix missign datamodule for PV site
1 parent 0c6b9c0 commit db74cd4

File tree

1 file changed

+147
-0
lines changed

1 file changed

+147
-0
lines changed

pvnet/data/pv_site_datamodule.py

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
""" Data module for pytorch lightning """
2+
import glob
3+
4+
from lightning.pytorch import LightningDataModule
5+
from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch
6+
from ocf_datapipes.training.pvnet_site import pvnet_site_netcdf_datapipe
7+
from torch.utils.data import DataLoader
8+
9+
from pvnet.data.utils import batch_to_tensor
10+
11+
12+
class WindDataModule(LightningDataModule):
13+
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""
14+
15+
def __init__(
16+
self,
17+
configuration=None,
18+
batch_size=16,
19+
num_workers=0,
20+
prefetch_factor=None,
21+
train_period=[None, None],
22+
val_period=[None, None],
23+
test_period=[None, None],
24+
batch_dir=None,
25+
):
26+
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`.
27+
28+
Can also be used with pre-made batches if `batch_dir` is set.
29+
30+
31+
Args:
32+
configuration: Path to datapipe configuration file.
33+
batch_size: Batch size.
34+
num_workers: Number of workers to use in multiprocess batch loading.
35+
prefetch_factor: Number of data will be prefetched at the end of each worker process.
36+
train_period: Date range filter for train dataloader.
37+
val_period: Date range filter for val dataloader.
38+
test_period: Date range filter for test dataloader.
39+
batch_dir: Path to the directory of pre-saved batches. Cannot be used together with
40+
'train/val/test_period'.
41+
42+
"""
43+
super().__init__()
44+
self.configuration = configuration
45+
self.batch_size = batch_size
46+
self.batch_dir = batch_dir
47+
48+
# if batch_dir is not None:
49+
# if any([period != [None, None] for period in [train_period, val_period, test_period]]):
50+
# raise ValueError("Cannot set `(train/val/test)_period` with presaved batches")
51+
52+
self.train_period = [None, None]
53+
# None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in train_period
54+
# ]
55+
self.val_period = [None, None]
56+
# None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in val_period
57+
# ]
58+
self.test_period = [None, None]
59+
# None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
60+
# ]
61+
62+
self._common_dataloader_kwargs = dict(
63+
batch_size=None, # batched in datapipe step
64+
sampler=None,
65+
batch_sampler=None,
66+
num_workers=num_workers,
67+
collate_fn=None,
68+
pin_memory=False,
69+
drop_last=False,
70+
timeout=0,
71+
worker_init_fn=None,
72+
prefetch_factor=prefetch_factor,
73+
persistent_workers=False,
74+
)
75+
76+
def _get_datapipe(self, start_time, end_time):
77+
data_pipeline = pvnet_site_netcdf_datapipe(
78+
self.configuration,
79+
keys=["pv", "nwp"],
80+
)
81+
82+
data_pipeline = (
83+
data_pipeline.batch(self.batch_size)
84+
.map(stack_np_examples_into_batch)
85+
.map(batch_to_tensor)
86+
)
87+
return data_pipeline
88+
89+
def _get_premade_batches_datapipe(self, subdir, shuffle=False):
90+
filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
91+
data_pipeline = pvnet_site_netcdf_datapipe(
92+
config_filename=self.configuration,
93+
keys=["pv", "nwp"],
94+
filenames=filenames,
95+
)
96+
data_pipeline = (
97+
data_pipeline.batch(self.batch_size)
98+
.map(stack_np_examples_into_batch)
99+
.map(batch_to_tensor)
100+
)
101+
if shuffle:
102+
data_pipeline = (
103+
data_pipeline.shuffle(buffer_size=100)
104+
.sharding_filter()
105+
# Split the batches and reshuffle them to be combined into new batches
106+
.split_batches(splitting_key=BatchKey.sensor)
107+
.shuffle(buffer_size=100 * self.batch_size)
108+
)
109+
else:
110+
data_pipeline = (
111+
data_pipeline.sharding_filter()
112+
# Split the batches so we can use any batch-size
113+
.split_batches(splitting_key=BatchKey.sensor)
114+
)
115+
116+
data_pipeline = (
117+
data_pipeline.batch(self.batch_size)
118+
.map(stack_np_examples_into_batch)
119+
.map(batch_to_tensor)
120+
.set_length(int(len(filenames) / self.batch_size))
121+
)
122+
123+
return data_pipeline
124+
125+
def train_dataloader(self):
126+
"""Construct train dataloader"""
127+
if self.batch_dir is not None:
128+
datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
129+
else:
130+
datapipe = self._get_datapipe(*self.train_period)
131+
return DataLoader(datapipe, shuffle=True, **self._common_dataloader_kwargs)
132+
133+
def val_dataloader(self):
134+
"""Construct val dataloader"""
135+
if self.batch_dir is not None:
136+
datapipe = self._get_premade_batches_datapipe("val")
137+
else:
138+
datapipe = self._get_datapipe(*self.val_period)
139+
return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)
140+
141+
def test_dataloader(self):
142+
"""Construct test dataloader"""
143+
if self.batch_dir is not None:
144+
datapipe = self._get_premade_batches_datapipe("test")
145+
else:
146+
datapipe = self._get_datapipe(*self.test_period)
147+
return DataLoader(datapipe, shuffle=False, **self._common_dataloader_kwargs)

0 commit comments

Comments
 (0)