Skip to content

Commit

Permalink
Address more PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Feb 6, 2024
1 parent d6a1035 commit d1adef2
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 17 deletions.
16 changes: 2 additions & 14 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,6 @@ def __init__(
# Number of timestemps for 30 minutely data
self.history_len = history_minutes // interval_minutes
self.forecast_len = forecast_minutes // interval_minutes
# self.forecast_len_15 = forecast_minutes // 15
# self.history_len_15 = history_minutes // 15

self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len)

Expand Down Expand Up @@ -389,13 +387,9 @@ def _calculate_val_losses(self, y, y_hat):

# Take median value for remaining metric calculations
y_hat = self._quantiles_to_prediction(y_hat)
common_metrics_each_step = {
"mae": torch.mean(torch.abs(y_hat - y), dim=0),
"rmse": torch.sqrt(torch.mean((y_hat - y) ** 2, dim=0)),
}
# common_metrics_each_step = common_metrics(predictions=y_hat.numpy(), target=y.numpy())
mse_each_step = common_metrics_each_step["rmse"] ** 2
mae_each_step = common_metrics_each_step["mae"]
mse_each_step = torch.sqrt(torch.mean((y_hat - y) ** 2, dim=0)) ** 2
mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)

losses.update({f"MSE_horizon/step_{i:03}": m for i, m in enumerate(mse_each_step)})
losses.update({f"MAE_horizon/step_{i:03}": m for i, m in enumerate(mae_each_step)})
Expand Down Expand Up @@ -452,8 +446,6 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):

def training_step(self, batch, batch_idx):
"""Run training step"""
# Make all -1 values 0.0
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
y_hat = self(batch)
y = batch[self._target_key][:, -self.forecast_len :, 0]

Expand All @@ -470,8 +462,6 @@ def training_step(self, batch, batch_idx):

def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""
# Make all -1 values 0.0
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len :, 0]
Expand Down Expand Up @@ -558,8 +548,6 @@ def validation_step(self, batch: dict, batch_idx):

def test_step(self, batch, batch_idx):
"""Run test step"""
# Make all -1 values 0.0
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
y_hat = self(batch)
y = batch[self._target_key][:, -self.forecast_len :, 0]

Expand Down
2 changes: 1 addition & 1 deletion pvnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def flush(self) -> dict[BatchKey, list[torch.Tensor]]:
"""Concatenate all accumulated batches, return, and clear self"""
batch = {}
for k, v in self._batches.items():
if k == BatchKey.gsp_t0_idx or k == BatchKey.wind_t0_idx or k == BatchKey.pv_t0_idx:
if k == f"{self.key_to_keep}_t0_idx":
batch[k] = v[0]
else:
batch[k] = torch.cat(v, dim=0)
Expand Down
2 changes: 0 additions & 2 deletions scripts/save_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
# Tired of seeing these warnings
import warnings

import dask
import hydra
import torch
from ocf_datapipes.batch import stack_np_examples_into_batch
Expand All @@ -39,7 +38,6 @@
from pvnet.data.utils import batch_to_tensor
from pvnet.utils import print_config

dask.config.set(scheduler="single-threaded")

warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

Expand Down

0 comments on commit d1adef2

Please sign in to comment.