@@ -277,8 +277,6 @@ def __init__(
277277 # Number of timestemps for 30 minutely data
278278 self .history_len = history_minutes // interval_minutes
279279 self .forecast_len = forecast_minutes // interval_minutes
280- # self.forecast_len_15 = forecast_minutes // 15
281- # self.history_len_15 = history_minutes // 15
282280
283281 self .weighted_losses = WeightedLosses (forecast_length = self .forecast_len )
284282
@@ -389,13 +387,9 @@ def _calculate_val_losses(self, y, y_hat):
389387
390388 # Take median value for remaining metric calculations
391389 y_hat = self ._quantiles_to_prediction (y_hat )
392- common_metrics_each_step = {
393- "mae" : torch .mean (torch .abs (y_hat - y ), dim = 0 ),
394- "rmse" : torch .sqrt (torch .mean ((y_hat - y ) ** 2 , dim = 0 )),
395- }
396390 # common_metrics_each_step = common_metrics(predictions=y_hat.numpy(), target=y.numpy())
397- mse_each_step = common_metrics_each_step [ "rmse" ] ** 2
398- mae_each_step = common_metrics_each_step [ "mae" ]
391+ mse_each_step = torch . sqrt ( torch . mean (( y_hat - y ) ** 2 , dim = 0 )) ** 2
392+ mae_each_step = torch . mean ( torch . abs ( y_hat - y ), dim = 0 )
399393
400394 losses .update ({f"MSE_horizon/step_{ i :03} " : m for i , m in enumerate (mse_each_step )})
401395 losses .update ({f"MAE_horizon/step_{ i :03} " : m for i , m in enumerate (mae_each_step )})
@@ -452,8 +446,6 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
452446
453447 def training_step (self , batch , batch_idx ):
454448 """Run training step"""
455- # Make all -1 values 0.0
456- batch [self ._target_key ] = batch [self ._target_key ].clamp (min = 0.0 )
457449 y_hat = self (batch )
458450 y = batch [self ._target_key ][:, - self .forecast_len :, 0 ]
459451
@@ -470,8 +462,6 @@ def training_step(self, batch, batch_idx):
470462
471463 def validation_step (self , batch : dict , batch_idx ):
472464 """Run validation step"""
473- # Make all -1 values 0.0
474- batch [self ._target_key ] = batch [self ._target_key ].clamp (min = 0.0 )
475465 y_hat = self (batch )
476466 # Sensor seems to be in batch, station, time order
477467 y = batch [self ._target_key ][:, - self .forecast_len :, 0 ]
@@ -558,8 +548,6 @@ def validation_step(self, batch: dict, batch_idx):
558548
559549 def test_step (self , batch , batch_idx ):
560550 """Run test step"""
561- # Make all -1 values 0.0
562- batch [self ._target_key ] = batch [self ._target_key ].clamp (min = 0.0 )
563551 y_hat = self (batch )
564552 y = batch [self ._target_key ][:, - self .forecast_len :, 0 ]
565553
0 commit comments