[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

pre-commit-ci[bot] committed Jan 19, 2024
1 parent 9cfc8a9 commit ea77b0e
Showing 3 changed files with 31 additions and 33 deletions.
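Nearly everything below is mechanical restyling of the kind pre-commit hooks apply: import reordering (import yaml moves below import wandb, matching isort's alphabetical ordering of third-party imports) and line rewrapping within a fixed line-length limit, as black does with its default of 88 characters. The commit does not show the hook configuration, so naming isort and black is an assumption. A minimal sketch of the rewrapping behavior, using a hypothetical stand-in function rather than the real pvnet classes:

# Hypothetical stand-in, for illustration only (not the real BatchAccumulator).
def batch_accumulator(key_to_keep: str) -> dict:
    return {"key_to_keep": key_to_keep}

# A wrapped call with no trailing comma after the last argument...
acc = batch_accumulator(
    key_to_keep="gsp"
)

# ...is collapsed onto one line once it fits within the length limit:
acc = batch_accumulator(key_to_keep="gsp")

# A trailing ("magic") comma after the last argument instead keeps the
# call exploded, one argument per line, as in several hunks below:
acc = batch_accumulator(
    key_to_keep="gsp",
)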
53 changes: 24 additions & 29 deletions pvnet/models/base_model.py
@@ -10,16 +10,15 @@
 import pandas as pd
 import torch
 import torch.nn.functional as F
-import yaml
 import wandb
+import yaml
 from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
 from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
 from huggingface_hub.file_download import hf_hub_download
 from huggingface_hub.hf_api import HfApi
 from huggingface_hub.utils._deprecation import _deprecate_positional_args
 from ocf_datapipes.batch import BatchKey
 from ocf_ml_metrics.evaluation.evaluation import evaluation
 from ocf_ml_metrics.metrics.errors import common_metrics
-
 from pvnet.models.utils import (
     BatchAccumulator,
@@ -275,9 +274,7 @@ def __init__(
         self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)

         self._accumulated_metrics = MetricAccumulator()
-        self._accumulated_batches = BatchAccumulator(
-            key_to_keep=self._target_key
-        )
+        self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key)
         self._accumulated_y_hat = PredAccumulator()

     @property
@@ -434,14 +431,14 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
         # We only create the figure every 8 log steps
         # This was reduced as it was creating figures too often
         if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
-            fig = plot_batch_forecasts(
-                batch,
-                y_hat,
-                batch_idx,
-                quantiles=self.output_quantiles,
-                key_to_plot=self._target_key_name,
-            )
-            fig.savefig("latest_logged_train_batch.png")
+            fig = plot_batch_forecasts(
+                batch,
+                y_hat,
+                batch_idx,
+                quantiles=self.output_quantiles,
+                key_to_plot=self._target_key_name,
+            )
+            fig.savefig("latest_logged_train_batch.png")

     def training_step(self, batch, batch_idx):
         """Run training step"""
@@ -477,18 +474,18 @@ def validation_step(self, batch: dict, batch_idx):
         # for each step in the forecast horizon
         # This is needed for the custom plot
         # And needs to be in order of step
-        x_values = [int(k.split("_")[-1].split("/")[0]) for k in logged_losses.keys() if "MAE_horizon/step" in k]
+        x_values = [
+            int(k.split("_")[-1].split("/")[0])
+            for k in logged_losses.keys()
+            if "MAE_horizon/step" in k
+        ]
         y_values = []
         for x in x_values:
             y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"])
         per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)]
         table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"])
         wandb.log(
-            {
-                "mae_vs_timestep": wandb.plot.line(
-                    table, "timestep", "MAE", title="MAE vs Timestep"
-                )
-            }
+            {"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")}
         )

         self.log_dict(
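For reference, the comprehension being rewrapped parses the forecast-step index out of metric keys of the form "MAE_horizon/step_NNN/val", and the table feeds a wandb line plot. A minimal self-contained sketch of both (offline wandb run; the MAE values are invented):

import wandb

# Recover the integer step from a key such as "MAE_horizon/step_012/val":
key = "MAE_horizon/step_012/val"
step = int(key.split("_")[-1].split("/")[0])  # -> 12

# Log MAE against timestep as a line plot, as the code above does:
wandb.init(mode="offline")
table = wandb.Table(data=[[0, 0.05], [1, 0.06], [2, 0.08]], columns=["timestep", "MAE"])
wandb.log({"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")})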
@@ -503,9 +500,7 @@ def validation_step(self, batch: dict, batch_idx):
             # Store these temporarily under self
             if not hasattr(self, "_val_y_hats"):
                 self._val_y_hats = PredAccumulator()
-                self._val_batches = BatchAccumulator(
-                    key_to_keep=self._target_key_name
-                )
+                self._val_batches = BatchAccumulator(key_to_keep=self._target_key_name)

             self._val_y_hats.append(y_hat)
             self._val_batches.append(batch)
@@ -515,16 +510,16 @@ def validation_step(self, batch: dict, batch_idx):
                 batch = self._val_batches.flush()

                 fig = plot_batch_forecasts(
-                    batch,
-                    y_hat,
-                    quantiles=self.output_quantiles,
-                    key_to_plot=self._target_key_name,
+                    batch,
+                    y_hat,
+                    quantiles=self.output_quantiles,
+                    key_to_plot=self._target_key_name,
                 )

                 self.logger.experiment.log(
-                    {
-                        f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
-                    }
+                    {
+                        f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
+                    }
                 )
                 del self._val_y_hats
                 del self._val_batches
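As with the training-plot hunk earlier, the removed and added lines here render identically, so the change appears to be indentation only. The call itself logs a matplotlib figure to Weights & Biases; a minimal sketch, assuming self.logger.experiment is the underlying wandb run (as with Lightning's WandbLogger):

import matplotlib

matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt
import wandb

wandb.init(mode="offline")  # offline run keeps the sketch self-contained
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0.10, 0.07, 0.09])
wandb.log({"val_forecast_samples/batch_idx_0": wandb.Image(fig)})
plt.close(fig)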
9 changes: 6 additions & 3 deletions pvnet/models/utils.py
@@ -99,9 +99,12 @@ def __bool__(self):

     # @staticmethod
     def _filter_batch_dict(self, d):
-        keep_keys = (
-            [BatchKey[self.key_to_keep], BatchKey[f"{self.key_to_keep}_id"], BatchKey[f"{self.key_to_keep}_t0_idx"], BatchKey[f"{self.key_to_keep}_time_utc"]]
-        )
+        keep_keys = [
+            BatchKey[self.key_to_keep],
+            BatchKey[f"{self.key_to_keep}_id"],
+            BatchKey[f"{self.key_to_keep}_t0_idx"],
+            BatchKey[f"{self.key_to_keep}_time_utc"],
+        ]
         return {k: v for k, v in d.items() if k in keep_keys}

     def append(self, batch: dict[BatchKey, list[torch.Tensor]]):
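The reflowed keep_keys list is behavior-preserving: the accumulator keeps only the target key plus its _id, _t0_idx, and _time_utc companions and drops everything else. A minimal stand-in sketch using plain strings in place of the BatchKey enum:

def filter_batch_dict(d: dict, key_to_keep: str) -> dict:
    # Keep the target key and its companion metadata keys only.
    keep_keys = [
        key_to_keep,
        f"{key_to_keep}_id",
        f"{key_to_keep}_t0_idx",
        f"{key_to_keep}_time_utc",
    ]
    return {k: v for k, v in d.items() if k in keep_keys}

batch = {"gsp": 1, "gsp_id": 2, "gsp_t0_idx": 3, "nwp": 4}
assert filter_batch_dict(batch, "gsp") == {"gsp": 1, "gsp_id": 2, "gsp_t0_idx": 3}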
2 changes: 1 addition & 1 deletion pvnet/utils.py
@@ -255,7 +255,7 @@ def _get_numpy(key):
     y = batch[y_key][:, :, 0].cpu().numpy()  # Select the one it is trained on
     y_hat = y_hat.cpu().numpy()
     gsp_ids = batch[y_id_key][:, 0].cpu().numpy().squeeze()
-    t0_idx = int(batch[t0_idx_key])
+    int(batch[t0_idx_key])
     plotting_name = key_to_plot.upper()

     gsp_ids = batch[y_id_key].cpu().numpy().squeeze()
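This last change removes an assignment to a variable that was presumably never read, the kind of fix an unused-variable hook applies. The bare int(batch[t0_idx_key]) left behind still evaluates and then discards its result, so it is a no-op unless the conversion is meant to raise on bad input; deleting the whole line would be the tidier follow-up. A hypothetical illustration:

batch = {"gsp_t0_idx": 7.0}  # hypothetical batch and key, for illustration
t0_idx_key = "gsp_t0_idx"
int(batch[t0_idx_key])  # evaluated, result discarded: no effect unless it raises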
