Skip to content

Commit d6a1035

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1e7b648 commit d6a1035

File tree

4 files changed

+10
-8
lines changed

4 files changed

+10
-8
lines changed

pvnet/models/base_model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def __init__(
241241
output_quantiles: Optional[list[float]] = None,
242242
target_key: str = "gsp",
243243
interval_minutes: int = 30,
244-
timestep_intervals_to_plot: Optional[list[int]] = None
244+
timestep_intervals_to_plot: Optional[list[int]] = None,
245245
):
246246
"""Abtstract base class for PVNet submodels.
247247
@@ -261,7 +261,9 @@ def __init__(
261261
self._target_key = BatchKey[f"{target_key}"]
262262
if timestep_intervals_to_plot is not None:
263263
for interval in timestep_intervals_to_plot:
264-
assert type(interval) in [list, tuple] and len(interval) == 2, ValueError(f"timestep_intervals_to_plot must be a list of tuples or lists of length 2, but got {timestep_intervals_to_plot=}")
264+
assert type(interval) in [list, tuple] and len(interval) == 2, ValueError(
265+
f"timestep_intervals_to_plot must be a list of tuples or lists of length 2, but got {timestep_intervals_to_plot=}"
266+
)
265267
self.time_step_intervals_to_plot = timestep_intervals_to_plot
266268

267269
# Model must have lr to allow tuning
@@ -453,7 +455,7 @@ def training_step(self, batch, batch_idx):
453455
# Make all -1 values 0.0
454456
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
455457
y_hat = self(batch)
456-
y = batch[self._target_key][:, -self.forecast_len:, 0]
458+
y = batch[self._target_key][:, -self.forecast_len :, 0]
457459

458460
losses = self._calculate_common_losses(y, y_hat)
459461
losses = {f"{k}/train": v for k, v in losses.items()}
@@ -472,7 +474,7 @@ def validation_step(self, batch: dict, batch_idx):
472474
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
473475
y_hat = self(batch)
474476
# Sensor seems to be in batch, station, time order
475-
y = batch[self._target_key][:, -self.forecast_len:, 0]
477+
y = batch[self._target_key][:, -self.forecast_len :, 0]
476478

477479
losses = self._calculate_common_losses(y, y_hat)
478480
losses.update(self._calculate_val_losses(y, y_hat))
@@ -559,7 +561,7 @@ def test_step(self, batch, batch_idx):
559561
# Make all -1 values 0.0
560562
batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
561563
y_hat = self(batch)
562-
y = batch[self._target_key][:, -self.forecast_len:, 0]
564+
y = batch[self._target_key][:, -self.forecast_len :, 0]
563565

564566
losses = self._calculate_common_losses(y, y_hat)
565567
losses.update(self._calculate_val_losses(y, y_hat))

pvnet/models/multimodal/deep_supervision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def multi_mode_forward(self, x):
266266
def training_step(self, batch, batch_idx):
267267
"""Training step"""
268268
y_hats = self.multi_mode_forward(batch)
269-
y = batch[BatchKey.gsp][:, -self.forecast_len:, 0]
269+
y = batch[BatchKey.gsp][:, -self.forecast_len :, 0]
270270

271271
losses = self._calculate_common_losses(y, y_hats["all"])
272272
losses = {f"{k}/train": v for k, v in losses.items()}

pvnet/models/multimodal/multimodal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def __init__(
133133
output_quantiles=output_quantiles,
134134
target_key=target_key,
135135
interval_minutes=interval_minutes,
136-
timestep_intervals_to_plot=timestep_intervals_to_plot
136+
timestep_intervals_to_plot=timestep_intervals_to_plot,
137137
)
138138

139139
# Number of features expected by the output_network

pvnet/models/multimodal/weather_residual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def multi_mode_forward(self, x):
295295
def training_step(self, batch, batch_idx):
296296
"""Run training step"""
297297
y_hats = self.multi_mode_forward(batch)
298-
y = batch[BatchKey.gsp][:, -self.forecast_len:, 0]
298+
y = batch[BatchKey.gsp][:, -self.forecast_len :, 0]
299299

300300
losses = self._calculate_common_losses(y, y_hats["weather_out"])
301301
losses = {f"{k}/train": v for k, v in losses.items()}

0 commit comments

Comments
 (0)