
Commit ea77b0e

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9cfc8a9 commit ea77b0e

3 files changed: +31 −33 lines

pvnet/models/base_model.py

Lines changed: 24 additions & 29 deletions
@@ -10,16 +10,15 @@
 import pandas as pd
 import torch
 import torch.nn.functional as F
-import yaml
 import wandb
+import yaml
 from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
 from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
 from huggingface_hub.file_download import hf_hub_download
 from huggingface_hub.hf_api import HfApi
 from huggingface_hub.utils._deprecation import _deprecate_positional_args
 from ocf_datapipes.batch import BatchKey
 from ocf_ml_metrics.evaluation.evaluation import evaluation
-from ocf_ml_metrics.metrics.errors import common_metrics

 from pvnet.models.utils import (
     BatchAccumulator,
@@ -275,9 +274,7 @@ def __init__(
         self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)

         self._accumulated_metrics = MetricAccumulator()
-        self._accumulated_batches = BatchAccumulator(
-            key_to_keep=self._target_key
-        )
+        self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key)
         self._accumulated_y_hat = PredAccumulator()

     @property
@@ -434,14 +431,14 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
         # We only create the figure every 8 log steps
         # This was reduced as it was creating figures too often
         if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
-            fig = plot_batch_forecasts(
-                batch,
-                y_hat,
-                batch_idx,
-                quantiles=self.output_quantiles,
-                key_to_plot=self._target_key_name,
-            )
-            fig.savefig("latest_logged_train_batch.png")
+            fig = plot_batch_forecasts(
+                batch,
+                y_hat,
+                batch_idx,
+                quantiles=self.output_quantiles,
+                key_to_plot=self._target_key_name,
+            )
+            fig.savefig("latest_logged_train_batch.png")

     def training_step(self, batch, batch_idx):
         """Run training step"""
@@ -477,18 +474,18 @@ def validation_step(self, batch: dict, batch_idx):
         # for each step in the forecast horizon
         # This is needed for the custom plot
         # And needs to be in order of step
-        x_values = [int(k.split("_")[-1].split("/")[0]) for k in logged_losses.keys() if "MAE_horizon/step" in k]
+        x_values = [
+            int(k.split("_")[-1].split("/")[0])
+            for k in logged_losses.keys()
+            if "MAE_horizon/step" in k
+        ]
         y_values = []
         for x in x_values:
             y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"])
         per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)]
         table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"])
         wandb.log(
-            {
-                "mae_vs_timestep": wandb.plot.line(
-                    table, "timestep", "MAE", title="MAE vs Timestep"
-                )
-            }
+            {"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")}
         )

         self.log_dict(
@@ -503,9 +500,7 @@ def validation_step(self, batch: dict, batch_idx):
         # Store these temporarily under self
         if not hasattr(self, "_val_y_hats"):
             self._val_y_hats = PredAccumulator()
-            self._val_batches = BatchAccumulator(
-                key_to_keep=self._target_key_name
-            )
+            self._val_batches = BatchAccumulator(key_to_keep=self._target_key_name)

         self._val_y_hats.append(y_hat)
         self._val_batches.append(batch)
@@ -515,16 +510,16 @@ def validation_step(self, batch: dict, batch_idx):
             batch = self._val_batches.flush()

             fig = plot_batch_forecasts(
-                batch,
-                y_hat,
-                quantiles=self.output_quantiles,
-                key_to_plot=self._target_key_name,
+                batch,
+                y_hat,
+                quantiles=self.output_quantiles,
+                key_to_plot=self._target_key_name,
             )

             self.logger.experiment.log(
-                {
-                    f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
-                }
+                {
+                    f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
+                }
             )
             del self._val_y_hats
             del self._val_batches
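
For context on the validation hunks above: the reformatted list comprehension pulls the forecast-step index out of metric names of the form MAE_horizon/step_NNN/val, and the per-step MAEs are then logged as a wandb line plot. Below is a minimal standalone sketch of that parsing; the loss values are invented for illustration.

    # Standalone sketch of the step-index parsing in validation_step.
    # Metric names follow the "MAE_horizon/step_NNN/val" pattern seen in
    # the diff; the values here are made up.
    logged_losses = {
        "MAE_horizon/step_000/val": 0.021,
        "MAE_horizon/step_001/val": 0.025,
        "MAE_horizon/step_002/val": 0.030,
        "quantile_loss/val": 0.05,  # unrelated key, filtered out
    }

    # "MAE_horizon/step_002/val".split("_")[-1] -> "002/val",
    # then .split("/")[0] -> "002", then int(...) -> 2.
    x_values = [
        int(k.split("_")[-1].split("/")[0])
        for k in logged_losses.keys()
        if "MAE_horizon/step" in k
    ]
    y_values = [logged_losses[f"MAE_horizon/step_{x:03}/val"] for x in x_values]

    print(x_values)  # [0, 1, 2]
    print(y_values)  # [0.021, 0.025, 0.03]

In the actual code these (x, y) pairs then feed wandb.Table and wandb.plot.line, exactly as shown in the hunk.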

pvnet/models/utils.py

Lines changed: 6 additions & 3 deletions
@@ -99,9 +99,12 @@ def __bool__(self):

     # @staticmethod
     def _filter_batch_dict(self, d):
-        keep_keys = (
-            [BatchKey[self.key_to_keep], BatchKey[f"{self.key_to_keep}_id"], BatchKey[f"{self.key_to_keep}_t0_idx"], BatchKey[f"{self.key_to_keep}_time_utc"]]
-        )
+        keep_keys = [
+            BatchKey[self.key_to_keep],
+            BatchKey[f"{self.key_to_keep}_id"],
+            BatchKey[f"{self.key_to_keep}_t0_idx"],
+            BatchKey[f"{self.key_to_keep}_time_utc"],
+        ]
         return {k: v for k, v in d.items() if k in keep_keys}

     def append(self, batch: dict[BatchKey, list[torch.Tensor]]):
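
The reformatted _filter_batch_dict keeps only the four batch entries tied to the accumulator's target key. Here is a self-contained sketch of the same pattern, with plain strings standing in for ocf_datapipes' BatchKey enum members (an assumption made only to keep the example dependency-free; the filtering logic is unchanged):

    # Sketch of BatchAccumulator._filter_batch_dict; the string keys are
    # hypothetical stand-ins for BatchKey enum members.
    def filter_batch_dict(d: dict, key_to_keep: str) -> dict:
        keep_keys = [
            key_to_keep,
            f"{key_to_keep}_id",
            f"{key_to_keep}_t0_idx",
            f"{key_to_keep}_time_utc",
        ]
        return {k: v for k, v in d.items() if k in keep_keys}

    batch = {"gsp": [1.0], "gsp_id": [42], "gsp_t0_idx": 7, "nwp": [0.3]}
    print(filter_batch_dict(batch, "gsp"))
    # {'gsp': [1.0], 'gsp_id': [42], 'gsp_t0_idx': 7}

The hook's change itself is purely cosmetic: the formatter (presumably black) expands the over-long one-line list into one element per line.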

pvnet/utils.py

Lines changed: 1 addition & 1 deletion
@@ -255,7 +255,7 @@ def _get_numpy(key):
     y = batch[y_key][:, :, 0].cpu().numpy()  # Select the one it is trained on
     y_hat = y_hat.cpu().numpy()
     gsp_ids = batch[y_id_key][:, 0].cpu().numpy().squeeze()
-    t0_idx = int(batch[t0_idx_key])
+    int(batch[t0_idx_key])
     plotting_name = key_to_plot.upper()

     gsp_ids = batch[y_id_key].cpu().numpy().squeeze()
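
One behavioural note on this last hunk: the hook removed the unused t0_idx binding but kept the bare int(batch[t0_idx_key]) call, which now evaluates its result and throws it away. A small runnable sketch of the difference, with hypothetical stand-ins for the batch dict and key:

    # Hypothetical stand-ins for the values in pvnet/utils.py.
    batch = {"gsp_t0_idx": 7.0}
    t0_idx_key = "gsp_t0_idx"

    # After the auto-fix: evaluated, then discarded (a no-op statement).
    int(batch[t0_idx_key])

    # If the index were still needed, the binding would have to stay:
    t0_idx = int(batch[t0_idx_key])
    print(t0_idx)  # 7

If the value really is dead, the whole statement could simply be deleted in a follow-up commit.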
