@@ -277,8 +277,6 @@ def __init__(
277
277
# Number of timestemps for 30 minutely data
278
278
self .history_len = history_minutes // interval_minutes
279
279
self .forecast_len = forecast_minutes // interval_minutes
280
- # self.forecast_len_15 = forecast_minutes // 15
281
- # self.history_len_15 = history_minutes // 15
282
280
283
281
self .weighted_losses = WeightedLosses (forecast_length = self .forecast_len )
284
282
@@ -389,13 +387,9 @@ def _calculate_val_losses(self, y, y_hat):
389
387
390
388
# Take median value for remaining metric calculations
391
389
y_hat = self ._quantiles_to_prediction (y_hat )
392
- common_metrics_each_step = {
393
- "mae" : torch .mean (torch .abs (y_hat - y ), dim = 0 ),
394
- "rmse" : torch .sqrt (torch .mean ((y_hat - y ) ** 2 , dim = 0 )),
395
- }
396
390
# common_metrics_each_step = common_metrics(predictions=y_hat.numpy(), target=y.numpy())
397
- mse_each_step = common_metrics_each_step [ "rmse" ] ** 2
398
- mae_each_step = common_metrics_each_step [ "mae" ]
391
+ mse_each_step = torch . sqrt ( torch . mean (( y_hat - y ) ** 2 , dim = 0 )) ** 2
392
+ mae_each_step = torch . mean ( torch . abs ( y_hat - y ), dim = 0 )
399
393
400
394
losses .update ({f"MSE_horizon/step_{ i :03} " : m for i , m in enumerate (mse_each_step )})
401
395
losses .update ({f"MAE_horizon/step_{ i :03} " : m for i , m in enumerate (mae_each_step )})
@@ -452,8 +446,6 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
452
446
453
447
def training_step (self , batch , batch_idx ):
454
448
"""Run training step"""
455
- # Make all -1 values 0.0
456
- batch [self ._target_key ] = batch [self ._target_key ].clamp (min = 0.0 )
457
449
y_hat = self (batch )
458
450
y = batch [self ._target_key ][:, - self .forecast_len :, 0 ]
459
451
@@ -470,8 +462,6 @@ def training_step(self, batch, batch_idx):
470
462
471
463
def validation_step (self , batch : dict , batch_idx ):
472
464
"""Run validation step"""
473
- # Make all -1 values 0.0
474
- batch [self ._target_key ] = batch [self ._target_key ].clamp (min = 0.0 )
475
465
y_hat = self (batch )
476
466
# Sensor seems to be in batch, station, time order
477
467
y = batch [self ._target_key ][:, - self .forecast_len :, 0 ]
@@ -558,8 +548,6 @@ def validation_step(self, batch: dict, batch_idx):
558
548
559
549
def test_step (self , batch , batch_idx ):
560
550
"""Run test step"""
561
- # Make all -1 values 0.0
562
- batch [self ._target_key ] = batch [self ._target_key ].clamp (min = 0.0 )
563
551
y_hat = self (batch )
564
552
y = batch [self ._target_key ][:, - self .forecast_len :, 0 ]
565
553
0 commit comments