From d1adef29c54780955437902e872321779870e901 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 6 Feb 2024 08:17:49 +0000 Subject: [PATCH] Address more PR comments --- pvnet/models/base_model.py | 16 ++-------------- pvnet/models/utils.py | 2 +- scripts/save_batches.py | 2 -- 3 files changed, 3 insertions(+), 17 deletions(-) diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index ef3b8ed9..33d37127 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -277,8 +277,6 @@ def __init__( # Number of timestemps for 30 minutely data self.history_len = history_minutes // interval_minutes self.forecast_len = forecast_minutes // interval_minutes - # self.forecast_len_15 = forecast_minutes // 15 - # self.history_len_15 = history_minutes // 15 self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len) @@ -389,13 +387,9 @@ def _calculate_val_losses(self, y, y_hat): # Take median value for remaining metric calculations y_hat = self._quantiles_to_prediction(y_hat) - common_metrics_each_step = { - "mae": torch.mean(torch.abs(y_hat - y), dim=0), - "rmse": torch.sqrt(torch.mean((y_hat - y) ** 2, dim=0)), - } # common_metrics_each_step = common_metrics(predictions=y_hat.numpy(), target=y.numpy()) - mse_each_step = common_metrics_each_step["rmse"] ** 2 - mae_each_step = common_metrics_each_step["mae"] + mse_each_step = torch.mean((y_hat - y) ** 2, dim=0) + mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0) losses.update({f"MSE_horizon/step_{i:03}": m for i, m in enumerate(mse_each_step)}) losses.update({f"MAE_horizon/step_{i:03}": m for i, m in enumerate(mae_each_step)}) @@ -452,8 +446,6 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat): def training_step(self, batch, batch_idx): """Run training step""" - # Make all -1 values 0.0 - batch[self._target_key] = batch[self._target_key].clamp(min=0.0) y_hat = self(batch) y = batch[self._target_key][:, -self.forecast_len :, 0] @@ -470,8 
+462,6 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch: dict, batch_idx): """Run validation step""" - # Make all -1 values 0.0 - batch[self._target_key] = batch[self._target_key].clamp(min=0.0) y_hat = self(batch) # Sensor seems to be in batch, station, time order y = batch[self._target_key][:, -self.forecast_len :, 0] @@ -558,8 +548,6 @@ def validation_step(self, batch: dict, batch_idx): def test_step(self, batch, batch_idx): """Run test step""" - # Make all -1 values 0.0 - batch[self._target_key] = batch[self._target_key].clamp(min=0.0) y_hat = self(batch) y = batch[self._target_key][:, -self.forecast_len :, 0] diff --git a/pvnet/models/utils.py b/pvnet/models/utils.py index 2511e0d4..8cb4b662 100644 --- a/pvnet/models/utils.py +++ b/pvnet/models/utils.py @@ -118,7 +118,7 @@ def flush(self) -> dict[BatchKey, list[torch.Tensor]]: """Concatenate all accumulated batches, return, and clear self""" batch = {} for k, v in self._batches.items(): - if k == BatchKey.gsp_t0_idx or k == BatchKey.wind_t0_idx or k == BatchKey.pv_t0_idx: + if k == BatchKey[f"{self.key_to_keep}_t0_idx"]: batch[k] = v[0] else: batch[k] = torch.cat(v, dim=0) diff --git a/scripts/save_batches.py b/scripts/save_batches.py index 54989e8e..b4d74e8f 100644 --- a/scripts/save_batches.py +++ b/scripts/save_batches.py @@ -23,7 +23,6 @@ # Tired of seeing these warnings import warnings -import dask import hydra import torch from ocf_datapipes.batch import stack_np_examples_into_batch @@ -39,7 +38,6 @@ from pvnet.data.utils import batch_to_tensor from pvnet.utils import print_config -dask.config.set(scheduler="single-threaded") warnings.filterwarnings("ignore", category=sa_exc.SAWarning)