@@ -45,9 +45,9 @@ def __init__(
4545 self .lr = None
4646
4747 def transfer_batch_to_device (
48- self ,
49- batch : TensorBatch ,
50- device : torch .device ,
48+ self ,
49+ batch : TensorBatch ,
50+ device : torch .device ,
5151 dataloader_idx : int ,
5252 ) -> dict :
5353 """Method to move custom batches to a given device"""
@@ -75,7 +75,7 @@ def _calculate_quantile_loss(self, y_quantiles: torch.Tensor, y: torch.Tensor) -
7575 losses = 2 * torch .cat (losses , dim = 2 )
7676
7777 return losses .mean ()
78-
78+
7979 def configure_optimizers (self ):
8080 """Configure the optimizers using learning rate found with LR finder if used"""
8181 if self .lr is not None :
@@ -84,7 +84,7 @@ def configure_optimizers(self):
8484 return self ._optimizer (self .model )
8585
8686 def _calculate_common_losses (
87- self ,
87+ self ,
8888 y : torch .Tensor ,
8989 y_hat : torch .Tensor ,
9090 ) -> dict [str , torch .Tensor ]:
@@ -96,15 +96,15 @@ def _calculate_common_losses(
9696 losses ["quantile_loss" ] = self ._calculate_quantile_loss (y_hat , y )
9797 y_hat = self .model ._quantiles_to_prediction (y_hat )
9898
99- losses .update ({"MSE" : F .mse_loss (y_hat , y ), "MAE" : F .l1_loss (y_hat , y )})
99+ losses .update ({"MSE" : F .mse_loss (y_hat , y ), "MAE" : F .l1_loss (y_hat , y )})
100100
101101 return losses
102-
102+
103103 def training_step (self , batch : TensorBatch , batch_idx : int ) -> torch .Tensor :
104104 """Run training step"""
105105 y_hat = self .model (batch )
106106
107- y = batch [self . model . _target_key ][:, - self .model .forecast_len :]
107+ y = batch ["generation" ][:, - self .model .forecast_len :]
108108
109109 losses = self ._calculate_common_losses (y , y_hat )
110110 losses = {f"{ k } /train" : v for k , v in losses .items ()}
@@ -116,10 +116,10 @@ def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
116116 else :
117117 opt_target = losses ["MAE/train" ]
118118 return opt_target
119-
119+
120120 def _calculate_val_losses (
121- self ,
122- y : torch .Tensor ,
121+ self ,
122+ y : torch .Tensor ,
123123 y_hat : torch .Tensor ,
124124 ) -> dict [str , torch .Tensor ]:
125125 """Calculate additional losses only run in validation"""
@@ -138,28 +138,25 @@ def _calculate_val_losses(
138138 return losses
139139
140140 def _calculate_step_metrics (
141- self ,
142- y : torch .Tensor ,
143- y_hat : torch .Tensor ,
141+ self ,
142+ y : torch .Tensor ,
143+ y_hat : torch .Tensor ,
144144 ) -> tuple [np .array , np .array ]:
145145 """Calculate the MAE and MSE at each forecast step"""
146146
147147 mae_each_step = torch .mean (torch .abs (y_hat - y ), dim = 0 ).cpu ().numpy ()
148148 mse_each_step = torch .mean ((y_hat - y ) ** 2 , dim = 0 ).cpu ().numpy ()
149-
149+
150150 return mae_each_step , mse_each_step
151-
151+
152152 def _store_val_predictions (self , batch : TensorBatch , y_hat : torch .Tensor ) -> None :
153153 """Internally store the validation predictions"""
154-
155- target_key = self .model ._target_key
156154
157- y = batch [target_key ][:, - self .model .forecast_len :].cpu ().numpy ()
158- y_hat = y_hat .cpu ().numpy ()
155+ y = batch ["generation" ][:, - self .model .forecast_len :].cpu ().numpy ()
156+ y_hat = y_hat .cpu ().numpy ()
159157 ids = batch ["location_id" ].cpu ().numpy ()
160158 init_times_utc = pd .to_datetime (
161- batch ["time_utc" ][:, self .model .history_len + 1 ]
162- .cpu ().numpy ().astype ("datetime64[ns]" )
159+ batch ["time_utc" ][:, self .model .history_len + 1 ].cpu ().numpy ().astype ("datetime64[ns]" )
163160 )
164161
165162 if self .model .use_quantile_regression :
@@ -170,7 +167,7 @@ def _store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> Non
170167
171168 ds_preds_batch = xr .Dataset (
172169 data_vars = dict (
173- y_hat = (["sample_num" , "forecast_step" , "p_level" ], y_hat ),
170+ y_hat = (["sample_num" , "forecast_step" , "p_level" ], y_hat ),
174171 y = (["sample_num" , "forecast_step" ], y ),
175172 ),
176173 coords = dict (
@@ -186,7 +183,7 @@ def on_validation_epoch_start(self):
186183 # Set up stores which we will fill during validation
187184 self .all_val_results : list [xr .Dataset ] = []
188185 self ._val_horizon_maes : list [np .array ] = []
189- if self .current_epoch == 0 :
186+ if self .current_epoch == 0 :
190187 self ._val_persistence_horizon_maes : list [np .array ] = []
191188
192189 # Plot some sample forecasts
@@ -197,29 +194,26 @@ def on_validation_epoch_start(self):
197194
198195 for plot_num in range (num_figures ):
199196 idxs = np .arange (plots_per_figure ) + plot_num * plots_per_figure
200- idxs = idxs [idxs < len (val_dataset )]
197+ idxs = idxs [idxs < len (val_dataset )]
201198
202- if len (idxs )== 0 :
199+ if len (idxs ) == 0 :
203200 continue
204201
205202 batch = collate_fn ([val_dataset [i ] for i in idxs ])
206203 batch = self .transfer_batch_to_device (batch , self .device , dataloader_idx = 0 )
207204
208205 # Batch validation check only during sanity check phase - use first batch
209206 if self .trainer .sanity_checking and plot_num == 0 :
210- validate_batch_against_config (
211- batch = batch ,
212- model = self .model
213- )
214-
207+ validate_batch_against_config (batch = batch , model = self .model )
208+
215209 with torch .no_grad ():
216210 y_hat = self .model (batch )
217-
211+
218212 fig = plot_sample_forecasts (
219213 batch ,
220214 y_hat ,
221215 quantiles = self .model .output_quantiles ,
222- key_to_plot = self . model . _target_key ,
216+ key_to_plot = "generation" ,
223217 )
224218
225219 plot_name = f"val_forecast_samples/sample_set_{ plot_num } "
@@ -238,7 +232,7 @@ def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
238232 # Internally store the val predictions
239233 self ._store_val_predictions (batch , y_hat )
240234
241- y = batch [self . model . _target_key ][:, - self .model .forecast_len :]
235+ y = batch ["generation" ][:, - self .model .forecast_len :]
242236
243237 losses = self ._calculate_common_losses (y , y_hat )
244238 losses = {f"{ k } /val" : v for k , v in losses .items ()}
@@ -262,21 +256,22 @@ def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
262256
263257 # Calculate the persistance losses - we only need to do this once per training run
264258 # not every epoch
265- if self .current_epoch == 0 :
259+ if self .current_epoch == 0 :
266260 y_persist = (
267- batch [self .model ._target_key ][:, - (self .model .forecast_len + 1 )]
268- .unsqueeze (1 ).expand (- 1 , self .model .forecast_len )
261+ batch ["generation" ][:, - (self .model .forecast_len + 1 )]
262+ .unsqueeze (1 )
263+ .expand (- 1 , self .model .forecast_len )
269264 )
270265 mae_step_persist , mse_step_persist = self ._calculate_step_metrics (y , y_persist )
271266 self ._val_persistence_horizon_maes .append (mae_step_persist )
272267 losses .update (
273268 {
274- "MAE/val_persistence" : mae_step_persist .mean (),
275- "MSE/val_persistence" : mse_step_persist .mean ()
269+ "MAE/val_persistence" : mae_step_persist .mean (),
270+ "MSE/val_persistence" : mse_step_persist .mean (),
276271 }
277272 )
278273
279- # Log the metrics
274+ # Log the metrics
280275 self .log_dict (losses , on_step = False , on_epoch = True )
281276
282277 def on_validation_epoch_end (self ) -> None :
@@ -289,7 +284,7 @@ def on_validation_epoch_end(self) -> None:
289284 self ._val_horizon_maes = []
290285
291286 # We only run this on the first epoch
292- if self .current_epoch == 0 :
287+ if self .current_epoch == 0 :
293288 val_persistence_horizon_maes = np .mean (self ._val_persistence_horizon_maes , axis = 0 )
294289 self ._val_persistence_horizon_maes = []
295290
@@ -321,25 +316,25 @@ def on_validation_epoch_end(self) -> None:
321316 wandb_log_dir = self .logger .experiment .dir
322317 filepath = f"{ wandb_log_dir } /validation_results.netcdf"
323318 ds_val_results .to_netcdf (filepath )
324-
325- # Uplodad to wandb
319+
320+ # Uplodad to wandb
326321 self .logger .experiment .save (filepath , base_path = wandb_log_dir , policy = "now" )
327-
322+
328323 # Create the horizon accuracy curve
329324 horizon_mae_plot = wandb_line_plot (
330- x = np .arange (self .model .forecast_len ),
325+ x = np .arange (self .model .forecast_len ),
331326 y = val_horizon_maes ,
332327 xlabel = "Horizon step" ,
333328 ylabel = "MAE" ,
334329 title = "Val horizon loss curve" ,
335330 )
336-
331+
337332 wandb .log ({"val_horizon_mae_plot" : horizon_mae_plot })
338333
339334 # Create persistence horizon accuracy curve but only on first epoch
340- if self .current_epoch == 0 :
335+ if self .current_epoch == 0 :
341336 persist_horizon_mae_plot = wandb_line_plot (
342- x = np .arange (self .model .forecast_len ),
337+ x = np .arange (self .model .forecast_len ),
343338 y = val_persistence_horizon_maes ,
344339 xlabel = "Horizon step" ,
345340 ylabel = "MAE" ,
0 commit comments