import pandas as pd
import torch
import torch.nn.functional as F
-import yaml
import wandb
+import yaml
from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils._deprecation import _deprecate_positional_args
from ocf_datapipes.batch import BatchKey
from ocf_ml_metrics.evaluation.evaluation import evaluation
-from ocf_ml_metrics.metrics.errors import common_metrics

from pvnet.models.utils import (
    BatchAccumulator,
@@ -275,9 +274,7 @@ def __init__(
        self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)

        self._accumulated_metrics = MetricAccumulator()
-        self._accumulated_batches = BatchAccumulator(
-            key_to_keep=self._target_key
-        )
+        self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key)
        self._accumulated_y_hat = PredAccumulator()

    @property
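
BatchAccumulator, PredAccumulator, and MetricAccumulator are defined in pvnet.models.utils and are not shown in this diff. From their usage here (append a little each step, flush to retrieve the whole and reset), they appear to follow a plain accumulate-then-flush pattern; a minimal sketch under that assumption, with illustrative names:

import torch

class SimplePredAccumulator:
    """Illustrative stand-in for PredAccumulator: buffers per-step
    predictions and concatenates them on flush."""

    def __init__(self):
        self._y_hats: list[torch.Tensor] = []

    def append(self, y_hat: torch.Tensor) -> None:
        self._y_hats.append(y_hat)

    def flush(self) -> torch.Tensor:
        # Join everything buffered so far, then reset for the next window
        y_hat = torch.cat(self._y_hats, dim=0)
        self._y_hats = []
        return y_hat
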
@@ -434,14 +431,14 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
        # We only create the figure every 8 log steps
        # This was reduced as it was creating figures too often
        if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
-            fig = plot_batch_forecasts(
-                batch,
-                y_hat,
-                batch_idx,
-                quantiles=self.output_quantiles,
-                key_to_plot=self._target_key_name,
-            )
-            fig.savefig("latest_logged_train_batch.png")
+            fig = plot_batch_forecasts(
+                batch,
+                y_hat,
+                batch_idx,
+                quantiles=self.output_quantiles,
+                key_to_plot=self._target_key_name,
+            )
+            fig.savefig("latest_logged_train_batch.png")

    def training_step(self, batch, batch_idx):
        """Run training step"""
@@ -477,18 +474,18 @@ def validation_step(self, batch: dict, batch_idx):
        # for each step in the forecast horizon
        # This is needed for the custom plot
        # And needs to be in order of step
-        x_values = [int(k.split("_")[-1].split("/")[0]) for k in logged_losses.keys() if "MAE_horizon/step" in k]
+        x_values = [
+            int(k.split("_")[-1].split("/")[0])
+            for k in logged_losses.keys()
+            if "MAE_horizon/step" in k
+        ]
        y_values = []
        for x in x_values:
            y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"])
        per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)]
        table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"])
        wandb.log(
-            {
-                "mae_vs_timestep": wandb.plot.line(
-                    table, "timestep", "MAE", title="MAE vs Timestep"
-                )
-            }
+            {"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")}
        )

        self.log_dict(
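
For context on the comprehension reformatted above: the per-horizon losses are logged under zero-padded keys, and the integer step is recovered by splitting on "_" and then "/". The round trip, with a hypothetical step value:

key = f"MAE_horizon/step_{7:03}/val"          # -> "MAE_horizon/step_007/val"
step = int(key.split("_")[-1].split("/")[0])  # "007/val" -> "007" -> 7
assert step == 7
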
@@ -503,9 +500,7 @@ def validation_step(self, batch: dict, batch_idx):
        # Store these temporarily under self
        if not hasattr(self, "_val_y_hats"):
            self._val_y_hats = PredAccumulator()
-            self._val_batches = BatchAccumulator(
-                key_to_keep=self._target_key_name
-            )
+            self._val_batches = BatchAccumulator(key_to_keep=self._target_key_name)

        self._val_y_hats.append(y_hat)
        self._val_batches.append(batch)
@@ -515,16 +510,16 @@ def validation_step(self, batch: dict, batch_idx):
            batch = self._val_batches.flush()

            fig = plot_batch_forecasts(
-                batch,
-                y_hat,
-                quantiles=self.output_quantiles,
-                key_to_plot=self._target_key_name,
+                batch,
+                y_hat,
+                quantiles=self.output_quantiles,
+                key_to_plot=self._target_key_name,
            )

            self.logger.experiment.log(
-                {
-                    f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
-                }
+                {
+                    f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
+                }
            )
            del self._val_y_hats
            del self._val_batches
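
The validation logging above relies on wandb.Image accepting a matplotlib figure directly. A self-contained sketch of the same pattern, assuming an active run (the project name and plotted values are illustrative):

import matplotlib.pyplot as plt
import wandb

run = wandb.init(project="pvnet-example")  # hypothetical project
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0.10, 0.20, 0.15])  # dummy forecast curve
run.log({"val_forecast_samples/batch_idx_0": wandb.Image(fig)})
plt.close(fig)
run.finish()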