Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 19, 2024
1 parent c14375b commit 901dee8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 31 deletions.
53 changes: 24 additions & 29 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@
import pandas as pd
import torch
import torch.nn.functional as F
import yaml
import wandb
import yaml
from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils._deprecation import _deprecate_positional_args
from ocf_datapipes.batch import BatchKey
from ocf_ml_metrics.evaluation.evaluation import evaluation
from ocf_ml_metrics.metrics.errors import common_metrics

from pvnet.models.utils import (
BatchAccumulator,
Expand Down Expand Up @@ -275,9 +274,7 @@ def __init__(
self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)

self._accumulated_metrics = MetricAccumulator()
self._accumulated_batches = BatchAccumulator(
key_to_keep=self._target_key_name
)
self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key_name)
self._accumulated_y_hat = PredAccumulator()

@property
Expand Down Expand Up @@ -434,14 +431,14 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
# We only create the figure every 8 log steps
# This was reduced as it was creating figures too often
if grad_batch_num % (8 * self.trainer.log_every_n_steps) == 0:
fig = plot_batch_forecasts(
batch,
y_hat,
batch_idx,
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
)
fig.savefig("latest_logged_train_batch.png")
fig = plot_batch_forecasts(
batch,
y_hat,
batch_idx,
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
)
fig.savefig("latest_logged_train_batch.png")

def training_step(self, batch, batch_idx):
"""Run training step"""
Expand Down Expand Up @@ -477,18 +474,18 @@ def validation_step(self, batch: dict, batch_idx):
# for each step in the forecast horizon
# This is needed for the custom plot
# And needs to be in order of step
x_values = [int(k.split("_")[-1].split("/")[0]) for k in logged_losses.keys() if "MAE_horizon/step" in k]
x_values = [
int(k.split("_")[-1].split("/")[0])
for k in logged_losses.keys()
if "MAE_horizon/step" in k
]
y_values = []
for x in x_values:
y_values.append(logged_losses[f"MAE_horizon/step_{x:03}/val"])
per_step_losses = [[x, y] for (x, y) in zip(x_values, y_values)]
table = wandb.Table(data=per_step_losses, columns=["timestep", "MAE"])
wandb.log(
{
"mae_vs_timestep": wandb.plot.line(
table, "timestep", "MAE", title="MAE vs Timestep"
)
}
{"mae_vs_timestep": wandb.plot.line(table, "timestep", "MAE", title="MAE vs Timestep")}
)

self.log_dict(
Expand All @@ -503,9 +500,7 @@ def validation_step(self, batch: dict, batch_idx):
# Store these temporarily under self
if not hasattr(self, "_val_y_hats"):
self._val_y_hats = PredAccumulator()
self._val_batches = BatchAccumulator(
key_to_keep=self._target_key_name
)
self._val_batches = BatchAccumulator(key_to_keep=self._target_key_name)

self._val_y_hats.append(y_hat)
self._val_batches.append(batch)
Expand All @@ -515,16 +510,16 @@ def validation_step(self, batch: dict, batch_idx):
batch = self._val_batches.flush()

fig = plot_batch_forecasts(
batch,
y_hat,
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
batch,
y_hat,
quantiles=self.output_quantiles,
key_to_plot=self._target_key_name,
)

self.logger.experiment.log(
{
f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
}
{
f"val_forecast_samples/batch_idx_{accum_batch_num}": wandb.Image(fig),
}
)
del self._val_y_hats
del self._val_batches
Expand Down
4 changes: 2 additions & 2 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
wind_history_minutes: Optional[int] = None,
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
target_key: str = "gsp",
interval_minutes: int = 30,
interval_minutes: int = 30,
):
"""Neural network which combines information from different sources.
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
optimizer=optimizer,
output_quantiles=output_quantiles,
target_key=target_key,
interval_minutes=interval_minutes
interval_minutes=interval_minutes,
)

# Number of features expected by the output_network
Expand Down

0 comments on commit 901dee8

Please sign in to comment.