Skip to content

Commit ad4bd08

Browse files
committed
Linting etc
1 parent 9e2b405 commit ad4bd08

16 files changed

+253
-112
lines changed

.flake8

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[flake8]
2+
max-line-length = 88
3+
exclude = .tox,.eggs,ci/templates,build,dist, __init__.py
4+
ignore = E741,F403,E265,W504,E226,W503,E501,E203

.isort.cfg

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[settings]
2+
profile=black

.pre-commit-config.yaml

+55-10
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,58 @@
1+
exclude: '^(\.tox|ci/templates|\.bumpversion\.cfg)(/|$)'
12
default_language_version:
2-
python: python3.9
3+
python: python3
34

45
repos:
5-
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v4.1.0
7-
hooks:
8-
# list of supported hooks: https://pre-commit.com/hooks.html
9-
- id: trailing-whitespace
10-
- id: end-of-file-fixer
11-
- id: check-yaml
12-
- id: debug-statements
13-
- id: detect-private-key
6+
- repo: https://github.com/pre-commit/pre-commit-hooks
7+
rev: v4.3.0
8+
hooks:
9+
- id: trailing-whitespace
10+
- id: check-docstring-first
11+
- id: check-added-large-files
12+
- id: check-ast
13+
- id: check-merge-conflict
14+
- id: debug-statements
15+
- id: end-of-file-fixer
16+
- id: mixed-line-ending
17+
args: ['--fix=lf']
18+
19+
- repo: https://github.com/asottile/pyupgrade
20+
rev: v2.37.3
21+
hooks:
22+
- id: pyupgrade
23+
args: ['--py39-plus']
24+
25+
- repo: https://github.com/myint/autoflake
26+
rev: v1.5.3
27+
hooks:
28+
- id: autoflake
29+
args: [
30+
--in-place,
31+
--remove-all-unused-imports,
32+
--remove-unused-variables,
33+
]
34+
35+
- repo: https://github.com/pycqa/isort
36+
rev: 5.10.1
37+
hooks:
38+
- id: isort
39+
args: [
40+
--sp=.isort.cfg,
41+
]
42+
43+
- repo: https://github.com/psf/black
44+
rev: 22.8.0
45+
hooks:
46+
- id: black
47+
- id: black-jupyter
48+
49+
- repo: https://github.com/PyCQA/flake8
50+
rev: 5.0.4
51+
hooks:
52+
- id: flake8
53+
54+
- repo: https://github.com/srstevenson/nb-clean
55+
rev: 2.2.1
56+
hooks:
57+
- id: nb-clean
58+
args: ['--remove-empty-cells']

predict_pv_yield/data/dataloader.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
1+
import logging
12
import os
3+
4+
import torch
25
from nowcasting_dataloader.datasets import NetCDFDataset, worker_init_fn
36
from nowcasting_dataloader.fake import FakeDataset
47
from nowcasting_dataset.config.load import load_yaml_configuration
5-
from typing import Tuple
6-
import logging
7-
import torch
88
from pytorch_lightning import LightningDataModule
99

10-
11-
1210
_LOG = logging.getLogger(__name__)
1311
_LOG.setLevel(logging.DEBUG)
1412

@@ -21,12 +19,16 @@ def get_dataloaders(
2119
cloud: str = "gcp",
2220
temp_path=".",
2321
data_path="prepared_ML_training_data/v4/",
24-
) -> Tuple:
22+
) -> tuple:
2523

26-
configuration = load_yaml_configuration(filename=f'{data_path}/configuration.yaml')
24+
# configuration = load_yaml_configuration(filename=f"{data_path}/configuration.yaml")
2725

2826
data_module = NetCDFDataModule(
29-
temp_path=temp_path, data_path=data_path, cloud=cloud, n_train_data=n_train_data, n_val_data=n_validation_data
27+
temp_path=temp_path,
28+
data_path=data_path,
29+
cloud=cloud,
30+
n_train_data=n_train_data,
31+
n_val_data=n_validation_data,
3032
)
3133

3234
train_dataloader = data_module.train_dataloader()
@@ -75,8 +77,8 @@ def __init__(
7577
self.pin_memory = pin_memory
7678
self.fake_data = fake_data
7779

78-
filename = os.path.join(data_path, 'configuration.yaml')
79-
_LOG.debug(f'Will be loading the configuration file {filename}')
80+
filename = os.path.join(data_path, "configuration.yaml")
81+
_LOG.debug(f"Will be loading the configuration file {filename}")
8082
self.configuration = load_yaml_configuration(filename=filename)
8183

8284
self.dataloader_config = dict(
@@ -98,7 +100,7 @@ def train_dataloader(self):
98100
self.n_train_data,
99101
os.path.join(self.data_path, "train"),
100102
os.path.join(self.temp_path, "train"),
101-
configuration=self.configuration
103+
configuration=self.configuration,
102104
)
103105

104106
return torch.utils.data.DataLoader(train_dataset, **self.dataloader_config)
@@ -111,7 +113,7 @@ def val_dataloader(self):
111113
self.n_val_data,
112114
os.path.join(self.data_path, "test"),
113115
os.path.join(self.temp_path, "test"),
114-
configuration=self.configuration
116+
configuration=self.configuration,
115117
)
116118

117119
return torch.utils.data.DataLoader(val_dataset, **self.dataloader_config)
@@ -125,7 +127,7 @@ def test_dataloader(self):
125127
self.n_val_data,
126128
os.path.join(self.data_path, "test"),
127129
os.path.join(self.temp_path, "test"),
128-
configuration=self.configuration
130+
configuration=self.configuration,
129131
)
130132

131133
return torch.utils.data.DataLoader(test_dataset, **self.dataloader_config)

predict_pv_yield/models/base_model.py

+48-30
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
1+
import logging
2+
3+
import numpy as np
4+
import pandas as pd
15
import pytorch_lightning as pl
26
import torch
37
import torch.nn.functional as F
4-
5-
from nowcasting_utils.visualization.visualization import plot_example
6-
from nowcasting_utils.visualization.line import plot_batch_results
8+
from nowcasting_dataloader.batch import BatchML
79
from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWP_VARIABLE_NAMES
10+
from nowcasting_utils.metrics.validation import (
11+
make_validation_results,
12+
save_validation_results_to_logger,
13+
)
814
from nowcasting_utils.models.loss import WeightedLosses
9-
from nowcasting_utils.models.metrics import mae_each_forecast_horizon, mse_each_forecast_horizon
10-
from nowcasting_dataloader.batch import BatchML
11-
from nowcasting_utils.metrics.validation import make_validation_results, save_validation_results_to_logger
12-
13-
import pandas as pd
14-
import numpy as np
15-
16-
import logging
15+
from nowcasting_utils.models.metrics import (
16+
mae_each_forecast_horizon,
17+
mse_each_forecast_horizon,
18+
)
19+
from nowcasting_utils.visualization.line import plot_batch_results
20+
from nowcasting_utils.visualization.visualization import plot_example
1721

1822
logger = logging.getLogger(__name__)
1923

@@ -75,7 +79,9 @@ def __init__(self):
7579

7680
self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len)
7781

78-
def _training_or_validation_step(self, batch, tag: str, return_model_outputs: bool = False):
82+
def _training_or_validation_step(
83+
self, batch, tag: str, return_model_outputs: bool = False
84+
):
7985
"""
8086
batch: The batch data
8187
tag: either 'Train', 'Validation' , 'Test'
@@ -120,8 +126,12 @@ def _training_or_validation_step(self, batch, tag: str, return_model_outputs: bo
120126

121127
if tag != "Train":
122128
# add metrics for each forecast horizon
123-
mse_each_forecast_horizon_metric = mse_each_forecast_horizon(output=y_hat, target=y)
124-
mae_each_forecast_horizon_metric = mae_each_forecast_horizon(output=y_hat, target=y)
129+
mse_each_forecast_horizon_metric = mse_each_forecast_horizon(
130+
output=y_hat, target=y
131+
)
132+
mae_each_forecast_horizon_metric = mae_each_forecast_horizon(
133+
output=y_hat, target=y
134+
)
125135

126136
metrics_mse = {
127137
f"MSE_forecast_horizon_{i}/{tag}": mse_each_forecast_horizon_metric[i]
@@ -167,7 +177,9 @@ def validation_step(self, batch: BatchML, batch_idx):
167177
if batch_idx in [0, 1, 2, 3, 4]:
168178

169179
# make sure the interesting example doesnt go above the batch size
170-
INTERESTING_EXAMPLES = (i for i in INTERESTING_EXAMPLES if i < self.batch_size)
180+
INTERESTING_EXAMPLES = (
181+
i for i in INTERESTING_EXAMPLES if i < self.batch_size
182+
)
171183

172184
for example_i in INTERESTING_EXAMPLES:
173185
# 1. Plot example
@@ -187,7 +199,7 @@ def validation_step(self, batch: BatchML, batch_idx):
187199
self.logger.experiment[-1].log_image(name, fig)
188200
try:
189201
fig.close()
190-
except Exception as _:
202+
except Exception:
191203
# could not close figure
192204
pass
193205

@@ -212,26 +224,30 @@ def validation_step(self, batch: BatchML, batch_idx):
212224
]
213225

214226
# plot and save to logger
215-
fig = plot_batch_results(model_name=self.name, y=y, y_hat=y_hat, x=time, x_hat=time_hat)
227+
fig = plot_batch_results(
228+
model_name=self.name, y=y, y_hat=y_hat, x=time, x_hat=time_hat
229+
)
216230
fig.write_html(f"temp_{batch_idx}.html")
217231
try:
218232
self.logger.experiment[-1][name].upload(f"temp_{batch_idx}.html")
219-
except:
233+
except Exception:
220234
pass
221235

222236
# save validation results
223-
capacity = batch.gsp.gsp_capacity[:,-self.forecast_len_30:,0].cpu().numpy()
237+
capacity = batch.gsp.gsp_capacity[:, -self.forecast_len_30 :, 0].cpu().numpy()
224238
predictions = model_output.cpu().numpy()
225-
truths = batch.gsp.gsp_yield[:, -self.forecast_len_30:, 0].cpu().numpy()
239+
truths = batch.gsp.gsp_yield[:, -self.forecast_len_30 :, 0].cpu().numpy()
226240
predictions = predictions * capacity
227241
truths = truths * capacity
228242

229-
results = make_validation_results(truths_mw=truths,
230-
predictions_mw=predictions,
231-
capacity_mwp=capacity,
232-
gsp_ids=batch.gsp.gsp_id[:, 0].cpu(),
233-
batch_idx=batch_idx,
234-
t0_datetimes_utc=pd.to_datetime(batch.metadata.t0_datetime_utc))
243+
results = make_validation_results(
244+
truths_mw=truths,
245+
predictions_mw=predictions,
246+
capacity_mwp=capacity,
247+
gsp_ids=batch.gsp.gsp_id[:, 0].cpu(),
248+
batch_idx=batch_idx,
249+
t0_datetimes_utc=pd.to_datetime(batch.metadata.t0_datetime_utc),
250+
)
235251

236252
# append so in 'validation_epoch_end' the file is saved
237253
if batch_idx == 0:
@@ -244,10 +260,12 @@ def validation_epoch_end(self, outputs):
244260

245261
logger.info("Validation epoch end")
246262

247-
save_validation_results_to_logger(results_dfs=self.results_dfs,
248-
results_file_name=self.results_file_name,
249-
current_epoch=self.current_epoch,
250-
logger=self.logger)
263+
save_validation_results_to_logger(
264+
results_dfs=self.results_dfs,
265+
results_file_name=self.results_file_name,
266+
current_epoch=self.current_epoch,
267+
logger=self.logger,
268+
)
251269

252270
def test_step(self, batch, batch_idx):
253271
self._training_or_validation_step(batch, tag="Test")

predict_pv_yield/models/conv3d/model_sat_nwp.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
import torch
44
import torch.nn.functional as F
5+
from nowcasting_dataloader.batch import BatchML
56
from torch import nn
67

78
from predict_pv_yield.models.base_model import BaseModel
8-
from nowcasting_dataloader.batch import BatchML
99

1010
logging.basicConfig()
1111
_LOG = logging.getLogger("predict_pv_yield")
@@ -139,10 +139,12 @@ def __init__(
139139
setattr(self, f"nwp_conv{i + 1}", layer)
140140

141141
self.nwp_fc1 = nn.Linear(
142-
in_features=self.nwp_cnn_output_size, out_features=self.fc1_output_features
142+
in_features=self.nwp_cnn_output_size,
143+
out_features=self.fc1_output_features,
143144
)
144145
self.nwp_fc2 = nn.Linear(
145-
in_features=self.fc1_output_features, out_features=self.number_of_nwp_features
146+
in_features=self.fc1_output_features,
147+
out_features=self.number_of_nwp_features,
146148
)
147149

148150
if self.embedding_dem:
@@ -152,22 +154,29 @@ def __init__(
152154

153155
if self.include_pv_yield_history:
154156
self.pv_fc1 = nn.Linear(
155-
in_features=self.number_of_pv_samples_per_batch * (self.history_len_5 + 1),
157+
in_features=self.number_of_pv_samples_per_batch
158+
* (self.history_len_5 + 1),
156159
out_features=128,
157160
)
158161

159162
fc3_in_features = self.fc2_output_features
160163
if include_pv_or_gsp_yield_history:
161-
fc3_in_features += self.number_of_samples_per_batch * (self.history_len_30 + 1)
164+
fc3_in_features += self.number_of_samples_per_batch * (
165+
self.history_len_30 + 1
166+
)
162167
if include_nwp:
163168
fc3_in_features += 128
164169
if self.embedding_dem:
165170
fc3_in_features += self.embedding_dem
166171
if self.include_pv_yield_history:
167172
fc3_in_features += 128
168173

169-
self.fc3 = nn.Linear(in_features=fc3_in_features, out_features=self.fc3_output_features)
170-
self.fc4 = nn.Linear(in_features=self.fc3_output_features, out_features=self.forecast_len)
174+
self.fc3 = nn.Linear(
175+
in_features=fc3_in_features, out_features=self.fc3_output_features
176+
)
177+
self.fc4 = nn.Linear(
178+
in_features=self.fc3_output_features, out_features=self.forecast_len
179+
)
171180
# self.fc5 = nn.Linear(in_features=32, out_features=8)
172181
# self.fc6 = nn.Linear(in_features=8, out_features=1)
173182

@@ -201,15 +210,20 @@ def forward(self, x):
201210
if self.include_pv_or_gsp_yield_history:
202211
if self.output_variable == "gsp_yield":
203212
pv_yield_history = (
204-
x.gsp.gsp_yield[:, : self.history_len_30 + 1].nan_to_num(nan=0.0).float()
213+
x.gsp.gsp_yield[:, : self.history_len_30 + 1]
214+
.nan_to_num(nan=0.0)
215+
.float()
205216
)
206217
else:
207218
pv_yield_history = (
208-
x.pv.pv_yield[:, : self.history_len_30 + 1].nan_to_num(nan=0.0).float()
219+
x.pv.pv_yield[:, : self.history_len_30 + 1]
220+
.nan_to_num(nan=0.0)
221+
.float()
209222
)
210223

211224
pv_yield_history = pv_yield_history.reshape(
212-
pv_yield_history.shape[0], pv_yield_history.shape[1] * pv_yield_history.shape[2]
225+
pv_yield_history.shape[0],
226+
pv_yield_history.shape[1] * pv_yield_history.shape[2],
213227
)
214228
# join up
215229
out = torch.cat((out, pv_yield_history), dim=1)
@@ -218,11 +232,14 @@ def forward(self, x):
218232
if self.include_pv_yield_history:
219233
# just take the first 128
220234
pv_yield_history = (
221-
x.pv.pv_yield[:, : self.history_len_5 + 1, :128].nan_to_num(nan=0.0).float()
235+
x.pv.pv_yield[:, : self.history_len_5 + 1, :128]
236+
.nan_to_num(nan=0.0)
237+
.float()
222238
)
223239

224240
pv_yield_history = pv_yield_history.reshape(
225-
pv_yield_history.shape[0], pv_yield_history.shape[1] * pv_yield_history.shape[2]
241+
pv_yield_history.shape[0],
242+
pv_yield_history.shape[1] * pv_yield_history.shape[2],
226243
)
227244
pv_yield_history = F.relu(self.pv_fc1(pv_yield_history))
228245

0 commit comments

Comments
 (0)