diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 76320825..cd5c2188 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -10,8 +10,8 @@ import pandas as pd import torch import torch.nn.functional as F -import yaml import wandb +import yaml from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from huggingface_hub.file_download import hf_hub_download @@ -19,7 +19,6 @@ from huggingface_hub.utils._deprecation import _deprecate_positional_args from ocf_datapipes.batch import BatchKey from ocf_ml_metrics.evaluation.evaluation import evaluation -from ocf_ml_metrics.metrics.errors import common_metrics from pvnet.models.utils import ( BatchAccumulator, @@ -275,9 +274,7 @@ def __init__( self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30) self._accumulated_metrics = MetricAccumulator() - self._accumulated_batches = BatchAccumulator( - key_to_keep=self._target_key - ) + self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key) self._accumulated_y_hat = PredAccumulator() @property @@ -434,14 +431,14 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat): # We only create the figure every 8 log steps # This was reduced as it was creating figures too often if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0: - fig = plot_batch_forecasts( - batch, - y_hat, - batch_idx, - quantiles=self.output_quantiles, - key_to_plot=self._target_key_name, - ) - fig.savefig("latest_logged_train_batch.png") + fig = plot_batch_forecasts( + batch, + y_hat, + batch_idx, + quantiles=self.output_quantiles, + key_to_plot=self._target_key_name, + ) + fig.savefig("latest_logged_train_batch.png") def training_step(self, batch, batch_idx): """Run training step""" @@ -477,18 +474,18 @@ def validation_step(self, batch: dict, batch_idx): # for each step in the forecast horizon # This is needed for the custom plot # And needs to be in order of step - x_values = [int(k.split("_")[-1].split("/")[0]) for k in logged_losses.keys() if "MAE_horizon/step" in k] + x_values = [ + int(k.split("_")[-1].split("/")[0]) + for k in logged_losses.keys() + if "MAE_horizon/step" in k + ] y_values = [] for x in x_values: y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"]) per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)] table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"]) wandb.log( - { - "mae_vs_timestep": wandb.plot.line( - table, "timestep", "MAE", title="MAE vs Timestep" - ) - } + {"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")} ) self.log_dict( @@ -503,9 +500,7 @@ def validation_step(self, batch: dict, batch_idx): # Store these temporarily under self if not hasattr(self, "_val_y_hats"): self._val_y_hats = PredAccumulator() - self._val_batches = BatchAccumulator( - key_to_keep=self._target_key_name - ) + self._val_batches = BatchAccumulator(key_to_keep=self._target_key_name) self._val_y_hats.append(y_hat) self._val_batches.append(batch) @@ -515,16 +510,16 @@ def validation_step(self, batch: dict, batch_idx): batch = self._val_batches.flush() fig = plot_batch_forecasts( - batch, - y_hat, - quantiles=self.output_quantiles, - key_to_plot=self._target_key_name, + batch, + y_hat, + quantiles=self.output_quantiles, + key_to_plot=self._target_key_name, ) self.logger.experiment.log( - { - f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig), - } + { + f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig), + } ) del self._val_y_hats del self._val_batches diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index 811c5397..2511e0d4 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -99,9 +99,12 @@ def __bool__(self): # @staticmethod def _filter_batch_dict(self, d): - keep_keys = ( - [BatchKey[self.key_to_keep], BatchKey[f"{self.key_to_keep}_id"], BatchKey[f"{self.key_to_keep}_t0_idx"], BatchKey[f"{self.key_to_keep}_time_utc"]] - ) + keep_keys = [ + BatchKey[self.key_to_keep], + BatchKey[f"{self.key_to_keep}_id"], + BatchKey[f"{self.key_to_keep}_t0_idx"], + BatchKey[f"{self.key_to_keep}_time_utc"], + ] return {k: v for k, v in d.items() if k in keep_keys} def append(self, batch: dict[BatchKey, list[torch.Tensor]]): diff --git a/pvnet/utils.py b/pvnet/utils.py index 62f1c4a9..764fce84 100644 --- a/pvnet/utils.py +++ b/pvnet/utils.py @@ -255,7 +255,7 @@ def _get_numpy(key): y = batch[y_key][:, :, 0].cpu().numpy() # Select the one it is trained on y_hat = y_hat.cpu().numpy() gsp_ids = batch[y_id_key][:, 0].cpu().numpy().squeeze() - t0_idx = int(batch[t0_idx_key]) + int(batch[t0_idx_key]) plotting_name = key_to_plot.upper() gsp_ids = batch[y_id_key].cpu().numpy().squeeze()