Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 24, 2024
1 parent 6b3e504 commit 468b0b3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
18 changes: 11 additions & 7 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from typing import Dict, Optional, Union

import hydra
import matplotlib.pyplot as plt
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -490,11 +490,11 @@ def validation_step(self, batch: dict, batch_idx):
{"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")}
)

#self.log_dict(
# self.log_dict(
# logged_losses,
# on_step=False,
# on_epoch=True,
#)
# )

accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches

Expand Down Expand Up @@ -531,11 +531,13 @@ def validation_step(self, batch: dict, batch_idx):
y_hat,
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
timesteps_to_plot=[6,12] # 1:30 to 3 hours ahead
timesteps_to_plot=[6, 12], # 1:30 to 3 hours ahead
)
self.logger.experiment.log(
{
f"val_forecast_samples/batch_idx_{accum_batch_num}_1.5_to_3hr": wandb.Image(fig),
f"val_forecast_samples/batch_idx_{accum_batch_num}_1.5_to_3hr": wandb.Image(
fig
),
}
)
plt.close(fig)
Expand All @@ -546,11 +548,13 @@ def validation_step(self, batch: dict, batch_idx):
y_hat,
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
timesteps_to_plot=[60, 156] # 15 to 39 hours ahead
timesteps_to_plot=[60, 156], # 15 to 39 hours ahead
)
self.logger.experiment.log(
{
f"val_forecast_samples/batch_idx_{accum_batch_num}_15_to_39hr": wandb.Image(fig),
f"val_forecast_samples/batch_idx_{accum_batch_num}_15_to_39hr": wandb.Image(
fig
),
}
)
plt.close(fig)
Expand Down
13 changes: 10 additions & 3 deletions pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import os
import warnings
from collections.abc import Sequence

from typing import Optional

import lightning.pytorch as pl
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -243,15 +243,22 @@ def finish(
wandb.finish()


def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None, key_to_plot: str = "gsp", timesteps_to_plot: Optional[list[int]] = None):
def plot_batch_forecasts(
batch,
y_hat,
batch_idx=None,
quantiles=None,
key_to_plot: str = "gsp",
timesteps_to_plot: Optional[list[int]] = None,
):
"""Plot a batch of data and the forecast from that batch"""

def _get_numpy(key):
return batch[key].cpu().numpy().squeeze()

y_key = BatchKey[f"{key_to_plot}"]
y_id_key = BatchKey[f"{key_to_plot}_id"]
t0_idx_key = BatchKey[f"{key_to_plot}_t0_idx"]
BatchKey[f"{key_to_plot}_t0_idx"]
time_utc_key = BatchKey[f"{key_to_plot}_time_utc"]
y = batch[y_key][:, :, 0].cpu().numpy() # Select the one it is trained on
y_hat = y_hat.cpu().numpy()
Expand Down

0 comments on commit 468b0b3

Please sign in to comment.