import pandas as pd
import torch
import torch.nn.functional as F
-import yaml
import wandb
+import yaml
from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils._deprecation import _deprecate_positional_args
from ocf_datapipes.batch import BatchKey
from ocf_ml_metrics.evaluation.evaluation import evaluation
-from ocf_ml_metrics.metrics.errors import common_metrics

from pvnet.models.utils import (
    BatchAccumulator,
@@ -275,9 +274,7 @@ def __init__(
        self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)

        self._accumulated_metrics = MetricAccumulator()
-        self._accumulated_batches = BatchAccumulator(
-            key_to_keep=self._target_key
-        )
+        self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key)
        self._accumulated_y_hat = PredAccumulator()

    @property
@@ -434,14 +431,14 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
            # We only create the figure every 8 log steps
            # This was reduced as it was creating figures too often
            if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
-                fig = plot_batch_forecasts(
-                    batch,
-                    y_hat,
-                    batch_idx,
-                    quantiles=self.output_quantiles,
-                    key_to_plot=self._target_key_name,
-                )
-                fig.savefig("latest_logged_train_batch.png")
+                fig = plot_batch_forecasts(
+                    batch,
+                    y_hat,
+                    batch_idx,
+                    quantiles=self.output_quantiles,
+                    key_to_plot=self._target_key_name,
+                )
+                fig.savefig("latest_logged_train_batch.png")

    def training_step(self, batch, batch_idx):
        """Run training step"""
@@ -477,18 +474,18 @@ def validation_step(self, batch: dict, batch_idx):
        # for each step in the forecast horizon
        # This is needed for the custom plot
        # And needs to be in order of step
-        x_values = [int(k.split("_")[-1].split("/")[0]) for k in logged_losses.keys() if "MAE_horizon/step" in k]
+        x_values = [
+            int(k.split("_")[-1].split("/")[0])
+            for k in logged_losses.keys()
+            if "MAE_horizon/step" in k
+        ]
        y_values = []
        for x in x_values:
            y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"])
        per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)]
        table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"])
        wandb.log(
-            {
-                "mae_vs_timestep": wandb.plot.line(
-                    table, "timestep", "MAE", title="MAE vs Timestep"
-                )
-            }
+            {"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")}
        )

        self.log_dict(
@@ -503,9 +500,7 @@ def validation_step(self, batch: dict, batch_idx):
            # Store these temporarily under self
            if not hasattr(self, "_val_y_hats"):
                self._val_y_hats = PredAccumulator()
-                self._val_batches = BatchAccumulator(
-                    key_to_keep=self._target_key_name
-                )
+                self._val_batches = BatchAccumulator(key_to_keep=self._target_key_name)

            self._val_y_hats.append(y_hat)
            self._val_batches.append(batch)
@@ -515,16 +510,16 @@ def validation_step(self, batch: dict, batch_idx):
                batch = self._val_batches.flush()

                fig = plot_batch_forecasts(
-                    batch,
-                    y_hat,
-                    quantiles=self.output_quantiles,
-                    key_to_plot=self._target_key_name,
+                    batch,
+                    y_hat,
+                    quantiles=self.output_quantiles,
+                    key_to_plot=self._target_key_name,
                )

                self.logger.experiment.log(
-                    {
-                        f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
-                    }
+                    {
+                        f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
+                    }
                )
                del self._val_y_hats
                del self._val_batches