Skip to content

Commit d1adef2

Browse files
committed
Address more PR comments
1 parent d6a1035 commit d1adef2

File tree

3 files changed

+3
-17
lines changed

3 files changed

+3
-17
lines changed

pvnet/models/base_model.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,6 @@ def __init__(
277277
# Number of timestemps for 30 minutely data
278278
self.history_len = history_minutes // interval_minutes
279279
self.forecast_len = forecast_minutes // interval_minutes
280-
# self.forecast_len_15 = forecast_minutes // 15
281-
# self.history_len_15 = history_minutes // 15
282280

283281
self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len)
284282

@@ -389,13 +387,9 @@ def _calculate_val_losses(self, y, y_hat):
389387

390388
# Take median value for remaining metric calculations
391389
y_hat = self._quantiles_to_prediction(y_hat)
392-
common_metrics_each_step = {
393-
"mae": torch.mean(torch.abs(y_hat - y), dim=0),
394-
"rmse": torch.sqrt(torch.mean((y_hat - y) ** 2, dim=0)),
395-
}
396390
# common_metrics_each_step = common_metrics(predictions=y_hat.numpy(), target=y.numpy())
397-
mse_each_step = common_metrics_each_step["rmse"] ** 2
398-
mae_each_step = common_metrics_each_step["mae"]
391+
mse_each_step = torch.sqrt(torch.mean((y_hat - y) ** 2, dim=0)) ** 2
392+
mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0)
399393

400394
losses.update({f"MSE_horizon/step_{i:03}": m for i, m in enumerate(mse_each_step)})
401395
losses.update({f"MAE_horizon/step_{i:03}": m for i, m in enumerate(mae_each_step)})
@@ -452,8 +446,6 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
452446

453447
def training_step(self, batch, batch_idx):
454448
"""Run training step"""
455-
# Make all -1 values 0.0
456-
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
457449
y_hat = self(batch)
458450
y = batch[self._target_key][:, -self.forecast_len :, 0]
459451

@@ -470,8 +462,6 @@ def training_step(self, batch, batch_idx):
470462

471463
def validation_step(self, batch: dict, batch_idx):
472464
"""Run validation step"""
473-
# Make all -1 values 0.0
474-
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
475465
y_hat = self(batch)
476466
# Sensor seems to be in batch, station, time order
477467
y = batch[self._target_key][:, -self.forecast_len :, 0]
@@ -558,8 +548,6 @@ def validation_step(self, batch: dict, batch_idx):
558548

559549
def test_step(self, batch, batch_idx):
560550
"""Run test step"""
561-
# Make all -1 values 0.0
562-
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
563551
y_hat = self(batch)
564552
y = batch[self._target_key][:, -self.forecast_len :, 0]
565553

pvnet/models/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def flush(self) -> dict[BatchKey, list[torch.Tensor]]:
118118
"""Concatenate all accumulated batches, return, and clear self"""
119119
batch = {}
120120
for k, v in self._batches.items():
121-
if k == BatchKey.gsp_t0_idx or k == BatchKey.wind_t0_idx or k == BatchKey.pv_t0_idx:
121+
if k == f"{self.key_to_keep}_t0_idx":
122122
batch[k] = v[0]
123123
else:
124124
batch[k] = torch.cat(v, dim=0)

scripts/save_batches.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
# Tired of seeing these warnings
2424
import warnings
2525

26-
import dask
2726
import hydra
2827
import torch
2928
from ocf_datapipes.batch import stack_np_examples_into_batch
@@ -39,7 +38,6 @@
3938
from pvnet.data.utils import batch_to_tensor
4039
from pvnet.utils import print_config
4140

42-
dask.config.set(scheduler="single-threaded")
4341

4442
warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
4543

0 commit comments

Comments
 (0)