diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 842639a4..0a7f7cd2 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -10,8 +10,8 @@ import pandas as pd import torch import torch.nn.functional as F -import yaml import wandb +import yaml from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME from huggingface_hub.file_download import hf_hub_download @@ -19,7 +19,6 @@ from huggingface_hub.utils._deprecation import _deprecate_positional_args from ocf_datapipes.batch import BatchKey from ocf_ml_metrics.evaluation.evaluation import evaluation -from ocf_ml_metrics.metrics.errors import common_metrics from pvnet.models.utils import ( BatchAccumulator, @@ -275,9 +274,7 @@ def __init__( self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30) self._accumulated_metrics = MetricAccumulator() - self._accumulated_batches = BatchAccumulator( - key_to_keep=self._target_key_name - ) + self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key_name) self._accumulated_y_hat = PredAccumulator() @property @@ -434,14 +431,14 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat): # We only create the figure every 8 log steps # This was reduced as it was creating figures too often if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0: - fig = plot_batch_forecasts( - batch, - y_hat, - batch_idx, - quantiles=self.output_quantiles, - key_to_plot=self._target_key_name, - ) - fig.savefig("latest_logged_train_batch.png") + fig = plot_batch_forecasts( + batch, + y_hat, + batch_idx, + quantiles=self.output_quantiles, + key_to_plot=self._target_key_name, + ) + fig.savefig("latest_logged_train_batch.png") def training_step(self, batch, batch_idx): """Run training step""" @@ -477,18 +474,18 @@ def validation_step(self, batch: dict, batch_idx): # for each step in the forecast horizon # This is needed for the custom plot # And needs to be in order of step - x_values = [int(k.split("_")[-1].split("/")[0]) for k in logged_losses.keys() if "MAE_horizon/step" in k] + x_values = [ + int(k.split("_")[-1].split("/")[0]) + for k in logged_losses.keys() + if "MAE_horizon/step" in k + ] y_values = [] for x in x_values: y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"]) per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)] table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"]) wandb.log( - { - "mae_vs_timestep": wandb.plot.line( - table, "timestep", "MAE", title="MAE vs Timestep" - ) - } + {"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")} ) self.log_dict( @@ -503,9 +500,7 @@ def validation_step(self, batch: dict, batch_idx): # Store these temporarily under self if not hasattr(self, "_val_y_hats"): self._val_y_hats = PredAccumulator() - self._val_batches = BatchAccumulator( - key_to_keep=self._target_key_name - ) + self._val_batches = BatchAccumulator(key_to_keep=self._target_key_name) self._val_y_hats.append(y_hat) self._val_batches.append(batch) @@ -515,16 +510,16 @@ def validation_step(self, batch: dict, batch_idx): batch = self._val_batches.flush() fig = plot_batch_forecasts( - batch, - y_hat, - quantiles=self.output_quantiles, - key_to_plot=self._target_key_name, + batch, + y_hat, + quantiles=self.output_quantiles, + key_to_plot=self._target_key_name, ) self.logger.experiment.log( - { - f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig), - } + { + f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig), + } ) del self._val_y_hats del self._val_batches diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index fda96cf7..f6bfeeb8 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -58,7 +58,7 @@ def __init__( wind_history_minutes: Optional[int] = None, optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), target_key: str = "gsp", - interval_minutes: int = 30, + interval_minutes: int = 30, ): """Neural network which combines information from different sources. @@ -121,7 +121,7 @@ def __init__( optimizer=optimizer, output_quantiles=output_quantiles, target_key=target_key, - interval_minutes=interval_minutes + interval_minutes=interval_minutes, ) # Number of features expected by the output_network