
Commit 08a55ab

linting and clean up
1 parent 9faaeec commit 08a55ab

11 files changed: +100 -35 lines

pvnet/__init__.py (+1)

@@ -1 +1,2 @@
+"""PVNet"""
 __version__ = "0.0.8"

pvnet/callbacks.py (+30 -2)

@@ -1,25 +1,31 @@
-"""Custom callbacks developed to be able to use early stopping and learning rate finder even when
-pretraining parts of the network.
+"""Custom callbacks
 """
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import BaseFinetuning, EarlyStopping, LearningRateFinder
 from lightning.pytorch.trainer.states import TrainerFn


 class PhaseEarlyStopping(EarlyStopping):
+    """Monitor a validation metric and stop training when it stops improving.
+
+    Only functions in a specific phase of training.
+    """

     training_phase = None

     def switch_phase(self, phase: str):
+        """Switch phase of callback"""
         if phase == self.training_phase:
             self.activate()
         else:
             self.deactivate()

     def deactivate(self):
+        """Deactivate callback"""
         self.active = False

     def activate(self):
+        """Activate callback"""
         self.active = True

     def _should_skip_check(self, trainer: Trainer) -> bool:
@@ -30,21 +36,34 @@ def _should_skip_check(self, trainer: Trainer) -> bool:


 class PretrainEarlyStopping(EarlyStopping):
+    """Monitor a validation metric and stop training when it stops improving.
+
+    Only functions in the 'pretrain' phase of training.
+    """
     training_phase = "pretrain"


 class MainEarlyStopping(EarlyStopping):
+    """Monitor a validation metric and stop training when it stops improving.
+
+    Only functions in the 'main' phase of training.
+    """
     training_phase = "main"


 class PretrainFreeze(BaseFinetuning):
+    """Freeze the satellite and NWP encoders during pretraining
+    """

     training_phase = "pretrain"

     def __init__(self):
+        """Freeze the satellite and NWP encoders during pretraining
+        """
         super().__init__()

     def freeze_before_training(self, pl_module):
+        """Freeze satellite and NWP encoders before training start"""
         # freeze any module you want
         modules = []
         if pl_module.include_sat:
@@ -54,6 +73,7 @@ def freeze_before_training(self, pl_module):
         self.freeze(modules)

     def finetune_function(self, pl_module, current_epoch, optimizer):
+        """Unfreeze satellite and NWP encoders"""
         if not self.active:
             modules = []
             if pl_module.include_sat:
@@ -67,15 +87,18 @@ def finetune_function(self, pl_module, current_epoch, optimizer):
             )

     def switch_phase(self, phase: str):
+        """Switch phase of callback"""
         if phase == self.training_phase:
             self.activate()
         else:
             self.deactivate()

     def deactivate(self):
+        """Deactivate callback"""
         self.active = False

     def activate(self):
+        """Activate callback"""
         self.active = True


@@ -85,18 +108,23 @@ class PhasedLearningRateFinder(LearningRateFinder):
     active = True

     def on_fit_start(self, *args, **kwargs):
+        """Do nothing"""
         return

     def on_train_epoch_start(self, trainer, pl_module):
+        """Run learning rate finder on epoch start and then deactivate"""
         if self.active:
             self.lr_find(trainer, pl_module)
             self.deactivate()

     def switch_phase(self, phase: str):
+        """Switch training phase"""
         self.activate()

     def deactivate(self):
+        """Deactivate callback"""
         self.active = False

     def activate(self):
+        """Activate callback"""
         self.active = True

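For context, here is a minimal usage sketch (not part of this commit) showing how the phase-aware callbacks above could be wired into a Lightning Trainer; the monitored metric name "MAE/val", the patience value, and the commented-out model/datamodule are placeholders.

# Minimal usage sketch, not from this commit; metric name and model/datamodule are assumed.
from lightning.pytorch import Trainer

from pvnet.callbacks import PhasedLearningRateFinder, PretrainEarlyStopping, PretrainFreeze

callbacks = [
    PretrainEarlyStopping(monitor="MAE/val", patience=10),  # only acts in the "pretrain" phase
    PretrainFreeze(),                                        # freezes satellite/NWP encoders
    PhasedLearningRateFinder(),                              # re-runs LR find when a phase starts
]

# Point every phase-aware callback at the current training phase.
for callback in callbacks:
    if hasattr(callback, "switch_phase"):
        callback.switch_phase("pretrain")

trainer = Trainer(callbacks=callbacks)
# trainer.fit(model=model, datamodule=datamodule)
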
pvnet/data/__init__.py

Whitespace-only changes.

pvnet/data/datamodule.py (+1 -6)

@@ -20,12 +20,6 @@ def batch_to_tensor(batch):
     return batch


-def print_yaml(path):
-    print(f"{path} :")
-    with open(path, mode="r") as stream:
-        print("".join(stream.readlines()))
-
-
 def split_batches(batch):
     """Splits a single batch of data."""
     n_samples = batch[BatchKey.gsp].shape[0]
@@ -46,6 +40,7 @@ class BatchSplitter(IterDataPipe):
     """Pipeline step to split batches of data and yield single examples"""

     def __init__(self, source_datapipe: IterDataPipe):
+        """Pipeline step to split batches of data and yield single examples"""
         self.source_datapipe = source_datapipe

     def __iter__(self):

pvnet/models/utils.py (+23 -10)

@@ -1,17 +1,19 @@
+"""Utility functions"""
+
 import numpy as np
 import torch
 from ocf_datapipes.utils.consts import BatchKey


 class PredAccumulator:
-    """A class for accumulating y-predictions when using grad accumulation and the batch size is
-    small.
+    """A class for accumulating y-predictions using grad accumulation and small batch size.

     Attributes:
         _y_hats (list[torch.Tensor]): List of prediction tensors
     """

     def __init__(self):
+        """Prediction accumulator"""
         self._y_hats = []

     def __bool__(self):
@@ -22,43 +24,51 @@ def append(self, y_hat: torch.Tensor):
         self._y_hats += [y_hat]

     def flush(self) -> torch.Tensor:
+        """Return all appended predictions as a single torch tensor and remove from accumulated store.
+        """
         y_hat = torch.cat(self._y_hats, dim=0)
         self._y_hats = []
         return y_hat


 class DictListAccumulator:
+    """Abstract class for accumulating dictionaries of lists"""
     @staticmethod
-    def dict_list_append(d1, d2):
+    def _dict_list_append(d1, d2):
         for k, v in d2.items():
             d1[k] += [v]

     @staticmethod
-    def dict_init_list(d):
+    def _dict_init_list(d):
         return {k: [v] for k, v in d.items()}


 class MetricAccumulator(DictListAccumulator):
-    """A class for accumulating, and finding the mean of logging metrics when using grad
+    """Dictionary of metrics accumulator.
+
+    A class for accumulating, and finding the mean of logging metrics when using grad
     accumulation and the batch size is small.

     Attributes:
         _metrics (Dict[str, list[float]]): Dictionary containing lists of metrics.
     """

     def __init__(self):
+        """Dictionary of metrics accumulator."""
         self._metrics = {}

     def __bool__(self):
         return self._metrics != {}

     def append(self, loss_dict: dict[str, float]):
+        """Append dictionary of metrics to self"""
         if not self:
-            self._metrics = self.dict_init_list(loss_dict)
+            self._metrics = self._dict_init_list(loss_dict)
         else:
-            self.dict_list_append(self._metrics, loss_dict)
+            self._dict_list_append(self._metrics, loss_dict)

     def flush(self) -> dict[str, float]:
+        """Calculate mean of all accumulated metrics and clear"""
         mean_metrics = {k: np.mean(v) for k, v in self._metrics.items()}
         self._metrics = {}
         return mean_metrics
@@ -72,23 +82,26 @@ class BatchAccumulator(DictListAccumulator):
     """

     def __init__(self):
+        """Batch accumulator"""
         self._batches = {}

     def __bool__(self):
         return self._batches != {}

     @staticmethod
-    def filter_batch_dict(d):
+    def _filter_batch_dict(d):
         keep_keys = [BatchKey.gsp, BatchKey.gsp_id, BatchKey.gsp_t0_idx, BatchKey.gsp_time_utc]
         return {k: v for k, v in d.items() if k in keep_keys}

     def append(self, batch: dict[BatchKey, list[torch.Tensor]]):
+        """Append batch to self"""
         if not self:
-            self._batches = self.dict_init_list(self.filter_batch_dict(batch))
+            self._batches = self._dict_init_list(self._filter_batch_dict(batch))
         else:
-            self.dict_list_append(self._batches, self.filter_batch_dict(batch))
+            self._dict_list_append(self._batches, self._filter_batch_dict(batch))

     def flush(self) -> dict[BatchKey, list[torch.Tensor]]:
+        """Concatenate all accumulated batches, return, and clear self"""
         batch = {}
         for k, v in self._batches.items():
             if k == BatchKey.gsp_t0_idx:

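As a quick illustration of how these accumulators behave under gradient accumulation, the sketch below appends per-micro-batch values and flushes once per optimizer step; it is not from the commit, and the metric names, tensor shapes, and loop are invented.

# Illustrative sketch only; metric names and shapes are made up.
import torch

from pvnet.models.utils import MetricAccumulator, PredAccumulator

preds = PredAccumulator()
metrics = MetricAccumulator()

for micro_batch in range(4):
    preds.append(torch.zeros(2, 16))  # fake y_hat with shape (batch, forecast_len)
    metrics.append({"MAE/train": 0.1 * micro_batch, "MSE/train": 0.01 * micro_batch})

y_hat = preds.flush()           # concatenated along dim 0 -> shape (8, 16)
mean_metrics = metrics.flush()  # {"MAE/train": 0.15, "MSE/train": 0.015}
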
pvnet/optimizers.py (+13 -1)

@@ -7,7 +7,9 @@


 class AbstractOptimizer(ABC):
-    """Optimizer classes will be used by model like:
+    """Abstract class for optimizer
+
+    Optimizer classes will be used by model like:
     > OptimizerGenerator = AbstractOptimizer()
     > optimizer = OptimizerGenerator(model.parameters())
     The returned object `optimizer` must be something that may be returned by `pytorch_lightning`'s
@@ -19,36 +21,46 @@ class AbstractOptimizer(ABC):

     @abstractmethod
     def __call__(self):
+        """Abstract call"""
         pass


 class Adam(AbstractOptimizer):
+    """Adam optimizer"""
     def __init__(self, lr=0.0005, **kwargs):
+        """Adam optimizer"""
         self.lr = lr
         self.kwargs = kwargs

     def __call__(self, model_parameters):
+        """Return optimizer"""
         return torch.optim.Adam(model_parameters, lr=self.lr, **self.kwargs)


 class AdamW(AbstractOptimizer):
+    """AdamW optimizer"""
     def __init__(self, lr=0.0005, **kwargs):
+        """AdamW optimizer"""
         self.lr = lr
         self.kwargs = kwargs

     def __call__(self, model_parameters):
+        """Return optimizer"""
         return torch.optim.AdamW(model_parameters, lr=self.lr, **self.kwargs)


 class AdamWReduceLROnPlateau(AbstractOptimizer):
+    """AdamW optimizer and reduce on plateau scheduler"""
     def __init__(self, lr=0.0005, patience=3, factor=0.5, threshold=2e-4, **opt_kwargs):
+        """AdamW optimizer and reduce on plateau scheduler"""
         self.lr = lr
         self.patience = patience
         self.factor = factor
         self.threshold = threshold
         self.opt_kwargs = opt_kwargs

     def __call__(self, model_parameters):
+        """Return optimizer"""
         opt = torch.optim.AdamW(model_parameters, lr=self.lr, **self.opt_kwargs)
         sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
             opt,

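The AbstractOptimizer docstring describes a factory pattern: the config instantiates one of these classes, and the model later calls it with its parameters. A hedged sketch of that pattern, with a toy linear model and arbitrary hyperparameters standing in for the real network and config:

# Sketch of the factory pattern from the AbstractOptimizer docstring; the toy
# model and chosen hyperparameters are placeholders, not values from the repo.
import torch

from pvnet.optimizers import AdamWReduceLROnPlateau

model = torch.nn.Linear(10, 1)

optimizer_generator = AdamWReduceLROnPlateau(lr=1e-4, patience=5)
# Presumably called from the model's configure_optimizers; the return value is
# whatever pytorch_lightning accepts there (optimizer plus scheduler in this case).
optimizer_and_scheduler = optimizer_generator(model.parameters())
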
pvnet/training.py (+6 -3)

@@ -1,3 +1,5 @@
+"""Training"""
+
 from typing import Optional

 import hydra
@@ -19,14 +21,15 @@
 torch.set_default_dtype(torch.float32)


-def callbacks_to_phase(callbacks, phase):
+def _callbacks_to_phase(callbacks, phase):
     for c in callbacks:
         if hasattr(c, "switch_phase"):
             c.switch_phase(phase)


 def train(config: DictConfig) -> Optional[float]:
     """Contains training pipeline.
+
     Instantiates all PyTorch Lightning objects from config.

     Args:
@@ -69,7 +72,7 @@ def train(config: DictConfig) -> Optional[float]:
         should_pretrain |= hasattr(c, "training_phase") and c.training_phase == "pretrain"

     if should_pretrain:
-        callbacks_to_phase(callbacks, "pretrain")
+        _callbacks_to_phase(callbacks, "pretrain")

     trainer: Trainer = hydra.utils.instantiate(
         config.trainer,
@@ -83,7 +86,7 @@ def train(config: DictConfig) -> Optional[float]:
         datamodule.block_nwp_and_sat = True
         trainer.fit(model=model, datamodule=datamodule)

-    callbacks_to_phase(callbacks, "main")
+    _callbacks_to_phase(callbacks, "main")

     datamodule.block_nwp_and_sat = False
     trainer.should_stop = False

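To put the renamed `_callbacks_to_phase` helper in context, the sketch below condenses the two-phase flow visible in `train()`: pretrain with satellite/NWP inputs blocked, then switch every phase-aware callback to "main" and fit again. It is a simplified, hypothetical helper, not the commit's `train()` function; the trainer, model, and datamodule are assumed to come from the hydra config.

# Hypothetical condensed version of the flow in pvnet.training.train.
def _callbacks_to_phase(callbacks, phase):
    for c in callbacks:
        if hasattr(c, "switch_phase"):
            c.switch_phase(phase)


def run_two_phase_training(trainer, model, datamodule, callbacks):
    """Pretrain with blocked inputs, then run the main training phase."""
    _callbacks_to_phase(callbacks, "pretrain")
    datamodule.block_nwp_and_sat = True   # pretrain without satellite/NWP data
    trainer.fit(model=model, datamodule=datamodule)

    _callbacks_to_phase(callbacks, "main")
    datamodule.block_nwp_and_sat = False  # main phase sees the full inputs
    trainer.should_stop = False           # allow the second fit to run
    trainer.fit(model=model, datamodule=datamodule)
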
pvnet/utils.py (+8 -1)

@@ -1,3 +1,5 @@
+"""Utils"""
+
 import logging
 import os
 import warnings
@@ -60,7 +62,9 @@ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:


 def extras(config: DictConfig) -> None:
-    """A couple of optional utilities, controlled by main config file:
+    """A couple of optional utilities.
+
+    Controlled by main config file:
     - disabling warnings
     - easier access to debug mode
     - forcing debug friendly configuration
@@ -143,6 +147,7 @@ def print_config(


 def empty(*args, **kwargs):
+    """Returns nothing"""
     pass


@@ -209,6 +214,7 @@ def finish(


 def plot_batch_forecasts(batch, y_hat, batch_idx=None):
+    """Plot a batch of data and the forecast from that batch"""
     def _get_numpy(key):
         return batch[key].cpu().numpy().squeeze()

@@ -254,6 +260,7 @@ def _get_numpy(key):


 def construct_ocf_ml_metrics_batch_df(batch, y, y_hat):
+    """Helper function to construct DataFrame for ocf_ml_metrics"""
     def _repeat(x):
         return np.repeat(x.squeeze(), n_times)
259266
