Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 6, 2024
1 parent 1e7b648 commit d6a1035
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
12 changes: 7 additions & 5 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def __init__(
output_quantiles: Optional[list[float]] = None,
target_key: str = "gsp",
interval_minutes: int = 30,
timestep_intervals_to_plot: Optional[list[int]] = None
timestep_intervals_to_plot: Optional[list[int]] = None,
):
"""Abtstract base class for PVNet submodels.
Expand All @@ -261,7 +261,9 @@ def __init__(
self._target_key = BatchKey[f"{target_key}"]
if timestep_intervals_to_plot is not None:
for interval in timestep_intervals_to_plot:
assert type(interval) in [list, tuple] and len(interval) == 2, ValueError(f"timestep_intervals_to_plot must be a list of tuples or lists of length 2, but got {timestep_intervals_to_plot=}")
assert type(interval) in [list, tuple] and len(interval) == 2, ValueError(
f"timestep_intervals_to_plot must be a list of tuples or lists of length 2, but got {timestep_intervals_to_plot=}"
)
self.time_step_intervals_to_plot = timestep_intervals_to_plot

# Model must have lr to allow tuning
Expand Down Expand Up @@ -453,7 +455,7 @@ def training_step(self, batch, batch_idx):
# Make all -1 values 0.0
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
y_hat = self(batch)
y = batch[self._target_key][:, -self.forecast_len:, 0]
y = batch[self._target_key][:, -self.forecast_len :, 0]

losses = self._calculate_common_losses(y, y_hat)
losses = {f"{k}/train": v for k, v in losses.items()}
Expand All @@ -472,7 +474,7 @@ def validation_step(self, batch: dict, batch_idx):
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len:, 0]
y = batch[self._target_key][:, -self.forecast_len :, 0]

losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))
Expand Down Expand Up @@ -559,7 +561,7 @@ def test_step(self, batch, batch_idx):
# Make all -1 values 0.0
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
y_hat = self(batch)
y = batch[self._target_key][:, -self.forecast_len:, 0]
y = batch[self._target_key][:, -self.forecast_len :, 0]

losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))
Expand Down
2 changes: 1 addition & 1 deletion pvnet/models/multimodal/deep_supervision.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def multi_mode_forward(self, x):
def training_step(self, batch, batch_idx):
"""Training step"""
y_hats = self.multi_mode_forward(batch)
y = batch[BatchKey.gsp][:, -self.forecast_len:, 0]
y = batch[BatchKey.gsp][:, -self.forecast_len :, 0]

losses = self._calculate_common_losses(y, y_hats["all"])
losses = {f"{k}/train": v for k, v in losses.items()}
Expand Down
2 changes: 1 addition & 1 deletion pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
output_quantiles=output_quantiles,
target_key=target_key,
interval_minutes=interval_minutes,
timestep_intervals_to_plot=timestep_intervals_to_plot
timestep_intervals_to_plot=timestep_intervals_to_plot,
)

# Number of features expected by the output_network
Expand Down
2 changes: 1 addition & 1 deletion pvnet/models/multimodal/weather_residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def multi_mode_forward(self, x):
def training_step(self, batch, batch_idx):
"""Run training step"""
y_hats = self.multi_mode_forward(batch)
y = batch[BatchKey.gsp][:, -self.forecast_len:, 0]
y = batch[BatchKey.gsp][:, -self.forecast_len :, 0]

losses = self._calculate_common_losses(y, y_hats["weather_out"])
losses = {f"{k}/train": v for k, v in losses.items()}
Expand Down

0 comments on commit d6a1035

Please sign in to comment.