Skip to content

Commit d436b64

Browse files
authored
Merge pull request #176 from openclimatefix/horizon_graph_gpu_fix
Fix the horizon graph training on GPU
2 parents 9c34c36 + 1b5283f commit d436b64

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pvnet/models/base_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ def validation_step(self, batch: dict, batch_idx):
521521

522522
# Store these to make horizon accuracy plot
523523
self._horizon_maes.append(
524-
{i: losses[f"MAE_horizon/step_{i:03}"] for i in range(self.forecast_len)}
524+
{i: losses[f"MAE_horizon/step_{i:03}"].cpu().numpy() for i in range(self.forecast_len)}
525525
)
526526

527527
logged_losses = {f"{k}/val": v for k, v in losses.items()}

0 commit comments

Comments
 (0)