Skip to content

Commit 075ae31

Browse files
authored
Add sample saving for Site Dataset + presaved SiteDataset/DataModule for loading presaved samples (#290)
1 parent 430e1d7 commit 075ae31

File tree

17 files changed

+230
-329
lines changed

17 files changed

+230
-329
lines changed

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,20 @@ This is also where you can update the train, val & test periods to cover the dat
145145

146146
### Running the batch creation script
147147

148-
Run the `save_batches.py` script to create batches with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example):
148+
Run the `save_samples.py` script to create samples with the parameters specified in the datamodule config (`streamed_batches.yaml` in this example):
149149

150150
```bash
151-
python scripts/save_batches.py
151+
python scripts/save_samples.py
152152
```
153153
PVNet uses
154154
[hydra](https://hydra.cc/) which enables us to pass variables via the command
155155
line that will override the configuration defined in the `./configs` directory, like this:
156156

157157
```bash
158-
python scripts/save_batches.py datamodule=streamed_batches datamodule.batch_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
158+
python scripts/save_samples.py datamodule=streamed_batches datamodule.sample_output_dir="./output" datamodule.num_train_batches=10 datamodule.num_val_batches=5
159159
```
160160

161-
`scripts/save_batches.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder.
161+
`scripts/save_samples.py` needs a config under `PVNet/configs/datamodule`. You can adapt `streamed_batches.yaml` or create your own in the same folder.
162162

163163
If downloading private data from a GCP bucket make sure to authenticate gcloud (the public satellite data does not need authentication):
164164

@@ -197,7 +197,7 @@ Make sure to update the following config files before training your model:
197197
2. In `configs/model/local_multimodal.yaml`:
198198
- update the list of encoders to reflect the data sources you are using. If you are using different NWP sources, the encoders for these should follow the same structure with two important updates:
199199
- `in_channels`: number of variables your NWP source supplies
200-
- `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `nwp_image_size_pixels_height` and/or `nwp_image_size_pixels_width` in `datamodule/example_configs.yaml`, unless transformations such as coarsening was applied (e. g. as for ECMWF data)
200+
- `image_size_pixels`: spatial crop of your NWP data. It depends on the spatial resolution of your NWP; should match `image_size_pixels_height` and/or `image_size_pixels_width` in `datamodule/configuration/site_example_configuration.yaml` for the NWP, unless transformations such as coarsening were applied (e.g. as for ECMWF data)
201201
3. In `configs/local_trainer.yaml`:
202202
- set `accelerator: 0` if running on a system without a supported GPU
203203

@@ -216,7 +216,7 @@ defaults:
216216
- hydra: default.yaml
217217
```
218218

219-
Assuming you ran the `save_batches.py` script to generate some premade train and
219+
Assuming you ran the `save_samples.py` script to generate some premade train and
220220
val data batches, you can now train PVNet by running:
221221

222222
```

pvnet/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
"""Data parts"""
2+
from .site_datamodule import SiteDataModule
3+
from .uk_regional_datamodule import DataModule
24
from .utils import BatchSplitter

pvnet/data/base.py

Lines changed: 0 additions & 116 deletions
This file was deleted.

pvnet/data/base_datamodule.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
""" Data module for pytorch lightning """
2+
from lightning.pytorch import LightningDataModule
3+
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
4+
from ocf_datapipes.batch import (
5+
NumpyBatch,
6+
TensorBatch,
7+
batch_to_tensor,
8+
)
9+
from torch.utils.data import DataLoader, Dataset
10+
11+
12+
def collate_fn(samples: list[NumpyBatch]) -> TensorBatch:
    """Stack a list of NumpySample samples into one batch and convert it to tensors."""
    stacked_batch = stack_np_samples_into_batch(samples)
    return batch_to_tensor(stacked_batch)
15+
16+
17+
class BaseDataModule(LightningDataModule):
    """Base Datamodule for training pvnet and using pvnet pipeline in ocf-data-sampler."""

    def __init__(
        self,
        configuration: str | None = None,
        sample_dir: str | None = None,
        batch_size: int = 16,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        train_period: list[str | None] | None = None,
        val_period: list[str | None] | None = None,
    ):
        """Base Datamodule for training pvnet architecture.

        Can also be used with pre-made batches if `sample_dir` is set.

        Args:
            configuration: Path to ocf-data-sampler configuration file.
            sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
                `configuration` or '[train/val]_period'.
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of batches loaded in advance by each worker process.
            train_period: Date range filter for train dataloader, as a two-element
                [start, end] list. Defaults to [None, None] (no filtering).
            val_period: Date range filter for val dataloader, as a two-element
                [start, end] list. Defaults to [None, None] (no filtering).

        Raises:
            ValueError: If neither or both of `sample_dir` and `configuration` are set, or if
                a period filter is combined with `sample_dir`.
        """
        super().__init__()

        # `None` defaults (rather than `[None, None]`) avoid the shared mutable-default-argument
        # pitfall; normalise to fresh lists here so behaviour matches the old defaults exactly.
        train_period = [None, None] if train_period is None else list(train_period)
        val_period = [None, None] if val_period is None else list(val_period)

        # Exactly one data source must be chosen: streamed (`configuration`)
        # or presaved (`sample_dir`).
        if not ((sample_dir is not None) ^ (configuration is not None)):
            raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")

        if sample_dir is not None:
            # Presaved samples are loaded as-is, so a period filter would silently do nothing.
            if any(period != [None, None] for period in [train_period, val_period]):
                raise ValueError("Cannot set `(train/val)_period` with presaved samples")

        self.configuration = configuration
        self.sample_dir = sample_dir
        self.train_period = train_period
        self.val_period = val_period

        # Shared DataLoader settings; `shuffle` is chosen per-split in the dataloader methods.
        self._common_dataloader_kwargs = dict(
            batch_size=batch_size,
            sampler=None,
            batch_sampler=None,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            prefetch_factor=prefetch_factor,
            persistent_workers=False,
        )

    def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
        """Return a dataset streaming samples for the given time range. Subclass hook."""
        raise NotImplementedError

    def _get_premade_samples_dataset(self, subdir) -> Dataset:
        """Return a dataset of presaved samples under the given subdirectory. Subclass hook."""
        raise NotImplementedError

    def train_dataloader(self) -> DataLoader:
        """Construct train dataloader"""
        if self.sample_dir is not None:
            dataset = self._get_premade_samples_dataset("train")
        else:
            dataset = self._get_streamed_samples_dataset(*self.train_period)
        return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)

    def val_dataloader(self) -> DataLoader:
        """Construct val dataloader"""
        if self.sample_dir is not None:
            dataset = self._get_premade_samples_dataset("val")
        else:
            dataset = self._get_streamed_samples_dataset(*self.val_period)
        return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)

pvnet/data/pv_site_datamodule.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

0 commit comments

Comments
 (0)