Commit db81147

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4f0a3dd commit db81147

5 files changed: +59 additions, -83 deletions


pvnet/data/datamodule.py

Lines changed: 14 additions & 24 deletions

@@ -1,15 +1,11 @@
 """ Data module for pytorch lightning """
-from datetime import datetime
 from glob import glob
 
-from lightning.pytorch import LightningDataModule
-from torch.utils.data import Dataset, DataLoader
 import torch
-
-from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch, NumpyBatch
-from ocf_data_sampler.torch_datasets.pvnet_uk_regional import (
-    PVNetUKRegionalDataset
-)
+from lightning.pytorch import LightningDataModule
+from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_datapipes.batch import NumpyBatch, batch_to_tensor, stack_np_examples_into_batch
+from torch.utils.data import DataLoader, Dataset
 
 
 def fill_nans_in_arrays(batch):
@@ -29,30 +25,28 @@ def fill_nans_in_arrays(batch):
     return batch
 
 
-
 class NumpybatchPremadeSamplesDataset(Dataset):
     """Dataset to load NumpyBatch samples"""
-
+
     def __init__(self, sample_dir):
         """Dataset to load NumpyBatch samples
-
+
         Args:
             sample_dir: Path to the directory of pre-saved samples.
         """
         self.sample_paths = glob(f"{sample_dir}/*.pt")
-
-
+
     def __len__(self):
         return len(self.sample_paths)
-
+
     def __getitem__(self, idx):
         return fill_nans_in_arrays(torch.load(self.sample_paths[idx]))
-
+
 
 def collate_fn(samples: list[NumpyBatch]):
     """Convert a list of NumpyBatch samples to a tensor batch"""
     return batch_to_tensor(stack_np_examples_into_batch(samples))
-
+
 
 class DataModule(LightningDataModule):
     """Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""
@@ -64,9 +58,8 @@ def __init__(
         batch_size: int = 16,
         num_workers: int = 0,
         prefetch_factor: int | None = None,
-        train_period: list[str|None] = [None, None],
-        val_period: list[str|None] = [None, None],
-
+        train_period: list[str | None] = [None, None],
+        val_period: list[str | None] = [None, None],
     ):
         """Datamodule for training pvnet architecture.
 
@@ -85,7 +78,6 @@ def __init__(
         """
         super().__init__()
 
-
         if not ((sample_dir is not None) ^ (configuration is not None)):
             raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")
 
@@ -118,21 +110,19 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
     def _get_premade_samples_dataset(self, subdir) -> Dataset:
         split_dir = f"{self.sample_dir}/{subdir}"
         return NumpybatchPremadeSamplesDataset(split_dir)
-
+
     def train_dataloader(self) -> DataLoader:
         """Construct train dataloader"""
         if self.sample_dir is not None:
             dataset = self._get_premade_samples_dataset("train")
         else:
             dataset = self._get_streamed_samples_dataset(*self.train_period)
         return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
-
+
     def val_dataloader(self) -> DataLoader:
         """Construct val dataloader"""
         if self.sample_dir is not None:
             dataset = self._get_premade_samples_dataset("val")
         else:
             dataset = self._get_streamed_samples_dataset(*self.val_period)
         return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
-
-
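
For orientation, a minimal usage sketch of this datamodule after the tidy-up. The sample directory path and worker count are illustrative, and it is assumed the remaining constructor arguments keep their defaults, so passing only `sample_dir` satisfies the XOR check shown above:

    from pvnet.data.datamodule import DataModule

    # Hypothetical directory containing pre-saved train/ and val/ *.pt samples
    datamodule = DataModule(
        sample_dir="/path/to/presaved/samples",
        batch_size=16,
        num_workers=2,
    )
    # Loads *.pt files via NumpybatchPremadeSamplesDataset, filling NaNs per sample
    train_loader = datamodule.train_dataloader()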

pvnet/models/base_model.py

Lines changed: 6 additions & 14 deletions

@@ -18,11 +18,7 @@
 from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
 from huggingface_hub.file_download import hf_hub_download
 from huggingface_hub.hf_api import HfApi
-
-from ocf_datapipes.batch import BatchKey
-from ocf_datapipes.batch import copy_batch_to_device
-
-from ocf_ml_metrics.evaluation.evaluation import evaluation
+from ocf_datapipes.batch import BatchKey, copy_batch_to_device
 
 from pvnet.models.utils import (
     BatchAccumulator,
@@ -32,8 +28,6 @@
 from pvnet.optimizers import AbstractOptimizer
 from pvnet.utils import plot_batch_forecasts
 
-
-
 DATA_CONFIG_NAME = "data_config.yaml"
 
 
@@ -239,13 +233,11 @@ def get_data_config(
         )
 
         return data_config_file
-
-
+
     def _save_pretrained(self, save_directory: Path) -> None:
         """Save weights from a Pytorch model to a local directory."""
         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
         torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)
-
 
     def save_pretrained(
         self,
@@ -416,14 +408,14 @@ def __init__(
             self.num_output_features = self.forecast_len * len(self.output_quantiles)
         else:
             self.num_output_features = self.forecast_len
-
+
         # save all validation results to array, so we can save these to weights n biases
         self.validation_epoch_results = []
 
     def transfer_batch_to_device(self, batch, device, dataloader_idx):
         """Method to move custom batches to a given device"""
         return copy_batch_to_device(batch, device)
-
+
     def _quantiles_to_prediction(self, y_quantiles):
         """
         Convert network prediction into a point prediction.
@@ -465,7 +457,7 @@ def _calculate_quantile_loss(self, y_quantiles, y):
             errors = y - y_quantiles[..., i]
             losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
         losses = 2 * torch.cat(losses, dim=2)
-
+
         return losses.mean()
 
     def _calculate_common_losses(self, y, y_hat):
@@ -659,7 +651,7 @@ def validation_step(self, batch: dict, batch_idx):
         accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
 
         y_hat = self(batch)
-
+
         y = batch[self._target_key][:, -self.forecast_len :]
 
         if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
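
The `_calculate_quantile_loss` hunk above shows the core of the model's pinball loss. For clarity, a self-contained sketch of the same computation, with the enclosing loop reconstructed from the visible context lines and the tensor shapes stated as assumptions:

    import torch

    def quantile_loss(y_quantiles: torch.Tensor, y: torch.Tensor, quantiles: list[float]) -> torch.Tensor:
        """Pinball loss mirroring the pattern in _calculate_quantile_loss.

        Assumed shapes: y_quantiles is (batch, forecast_len, n_quantiles),
        y is (batch, forecast_len).
        """
        losses = []
        for i, q in enumerate(quantiles):
            errors = y - y_quantiles[..., i]
            # Under-prediction is weighted by q, over-prediction by (1 - q)
            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
        losses = 2 * torch.cat(losses, dim=2)
        return losses.mean()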

pvnet/models/utils.py

Lines changed: 0 additions & 2 deletions

@@ -1,8 +1,6 @@
 """Utility functions"""
 
 import logging
-import math
-from typing import Optional
 
 import numpy as np
 import torch

pvnet/utils.py

Lines changed: 1 addition & 3 deletions

@@ -6,7 +6,6 @@
 
 import lightning.pytorch as pl
 import matplotlib.pyplot as plt
-import numpy as np
 import pandas as pd
 import pylab
 import rich.syntax
@@ -16,7 +15,6 @@
 from lightning.pytorch.utilities import rank_zero_only
 from ocf_datapipes.batch import BatchKey
 from ocf_datapipes.utils import Location
-from ocf_datapipes.utils.geospatial import osgb_to_lon_lat
 from omegaconf import DictConfig, OmegaConf
 
 
@@ -322,4 +320,4 @@ def _get_numpy(key):
     plt.suptitle(title)
     plt.tight_layout()
 
-    return fig
+    return fig

scripts/save_samples.py

Lines changed: 38 additions & 40 deletions

@@ -20,41 +20,38 @@
 ```
 if wanting to override these values for example
 """
-
+
 # Ensure this block of code runs only in the main process to avoid issues with worker processes.
 if __name__ == "__main__":
     import torch.multiprocessing as mp
-
-    # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be
-    # compatible with dask's multiprocessing.
+
+    # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be
+    # compatible with dask's multiprocessing.
     mp.set_start_method("forkserver")
-
-    # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is
-    # important because libraries like Zarr may open many files, which can exhaust the file
+
+    # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is
+    # important because libraries like Zarr may open many files, which can exhaust the file
     # descriptor limit if too many workers are used.
-    mp.set_sharing_strategy('file_system')
+    mp.set_sharing_strategy("file_system")
 
 
+import logging
 import os
-import sys
 import shutil
-import logging
+import sys
 import warnings
 
+import dask
 import hydra
+import torch
+from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
+from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
 
-import torch
-from torch.utils.data import Dataset, DataLoader
-
-from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
-
 from pvnet.utils import print_config
 
-import dask
-
 dask.config.set(scheduler="threads", num_workers=4)
 
 
@@ -71,6 +68,7 @@
 
 class SaveFuncFactory:
     """Factory for creating a function to save a sample to disk."""
+
     def __init__(self, save_dir: str, renewable: str = "pv"):
         self.save_dir = save_dir
         self.renewable = renewable
@@ -86,22 +84,22 @@ def __call__(self, sample, sample_num: int):
 
 def get_dataset(config_path: str, start_time: str, end_time: str, renewable: str = "pv") -> Dataset:
     """Get the dataset for the given renewable type."""
-    if renewable== "pv":
-        dataset_cls = PVNetUKRegionalDataset
+    if renewable == "pv":
+        dataset_cls = PVNetUKRegionalDataset
     elif renewable in ["wind", "pv_india", "pv_site"]:
         raise NotImplementedError
     else:
         raise ValueError(f"Unknown renewable: {renewable}")
-
+
     return dataset_cls(config_path, start_time=start_time, end_time=end_time)
 
 
 def save_samples_with_dataloader(
-    dataset: Dataset,
-    save_dir: str,
-    num_samples: int,
-    dataloader_kwargs: dict,
-    renewable: str = "pv"
+    dataset: Dataset,
+    save_dir: str,
+    num_samples: int,
+    dataloader_kwargs: dict,
+    renewable: str = "pv",
 ) -> None:
     """Save samples from a dataset using a dataloader."""
     save_func = SaveFuncFactory(save_dir, renewable=renewable)
@@ -124,7 +122,7 @@ def main(config: DictConfig) -> None:
 
     # Set up directory
     os.makedirs(config_dm.sample_output_dir, exist_ok=False)
-
+
     # Copy across configs which define the samples into the new sample directory
     with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f:
         f.write(OmegaConf.to_yaml(config_dm))
@@ -141,29 +139,29 @@ def main(config: DictConfig) -> None:
         batch_sampler=None,
         num_workers=config_dm.num_workers,
         collate_fn=None,
-        pin_memory=False, # Only using CPU to prepare samples so pinning is not beneficial
+        pin_memory=False,  # Only using CPU to prepare samples so pinning is not beneficial
         drop_last=False,
         timeout=0,
         worker_init_fn=None,
         prefetch_factor=config_dm.prefetch_factor,
-        persistent_workers=False, # Not needed since we only enter the dataloader loop once
+        persistent_workers=False,  # Not needed since we only enter the dataloader loop once
     )
 
     if config_dm.num_val_samples > 0:
         print("----- Saving val samples -----")
-
+
         val_output_dir = f"{config_dm.sample_output_dir}/val"
-
+
         # Make directory for val samples
         os.mkdir(val_output_dir)
-
-        # Get the dataset
+
+        # Get the dataset
         val_dataset = get_dataset(
             config_dm.configuration,
             *config_dm.val_period,
             renewable=config.renewable,
         )
-
+
         # Save samples
         save_samples_with_dataloader(
             dataset=val_dataset,
@@ -172,24 +170,24 @@ def main(config: DictConfig) -> None:
             dataloader_kwargs=dataloader_kwargs,
             renewable=config.renewable,
         )
-
+
         del val_dataset
 
     if config_dm.num_train_samples > 0:
         print("----- Saving train samples -----")
-
+
         train_output_dir = f"{config_dm.sample_output_dir}/train"
-
+
         # Make directory for train samples
         os.mkdir(train_output_dir)
-
-        # Get the dataset
+
+        # Get the dataset
         train_dataset = get_dataset(
             config_dm.configuration,
             *config_dm.train_period,
             renewable=config.renewable,
        )
-
+
         # Save samples
         save_samples_with_dataloader(
             dataset=train_dataset,
@@ -198,7 +196,7 @@ def main(config: DictConfig) -> None:
             dataloader_kwargs=dataloader_kwargs,
             renewable=config.renewable,
         )
-
+
         del train_dataset
 
     print("----- Saving complete -----")
