Skip to content

Commit fc89eaa

Browse files
committed
Remove target key attribute
1 parent f8400ec commit fc89eaa

File tree

4 files changed

+73
-71
lines changed

4 files changed

+73
-71
lines changed

pvnet/models/base_model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,9 +411,6 @@ def __init__(
411411
"""
412412
super().__init__()
413413

414-
# The key of the target variable in the batch
415-
self._target_key = "generation"
416-
417414
self.history_minutes = history_minutes
418415
self.forecast_minutes = forecast_minutes
419416
self.output_quantiles = output_quantiles

pvnet/models/late_fusion/late_fusion.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def __init__(
216216

217217
self.pv_encoder = pv_encoder(
218218
sequence_length=pv_history_minutes // pv_interval_minutes + 1,
219-
key_to_use=self._target_key,
219+
key_to_use="generation",
220220
)
221221

222222
# Update num features
@@ -296,15 +296,15 @@ def forward(self, x: TensorBatch) -> torch.Tensor:
296296
# *********************** Generation Data *************************************
297297
# Add generation yield history
298298
if self.include_generation_history:
299-
generation_history = x[self._target_key][:, : self.history_len + 1].float()
299+
generation_history = x["generation"][:, : self.history_len + 1].float()
300300
generation_history = generation_history.reshape(generation_history.shape[0], -1)
301-
modes[self._target_key] = generation_history
301+
modes["generation"] = generation_history
302302

303303
# Add location-level yield history through PV encoder
304304
if self.include_pv:
305305
x_tmp = x.copy()
306-
x_tmp[self._target_key] = x_tmp[self._target_key][:, : self.history_len + 1]
307-
modes[self._target_key] = self.pv_encoder(x_tmp)
306+
x_tmp["generation"] = x_tmp["generation"][:, : self.history_len + 1]
307+
modes["generation"] = self.pv_encoder(x_tmp)
308308

309309
# ********************** Embedding of location ID ********************
310310
if self.use_id_embedding:

pvnet/training/lightning_module.py

Lines changed: 43 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def __init__(
4545
self.lr = None
4646

4747
def transfer_batch_to_device(
48-
self,
49-
batch: TensorBatch,
50-
device: torch.device,
48+
self,
49+
batch: TensorBatch,
50+
device: torch.device,
5151
dataloader_idx: int,
5252
) -> dict:
5353
"""Method to move custom batches to a given device"""
@@ -75,7 +75,7 @@ def _calculate_quantile_loss(self, y_quantiles: torch.Tensor, y: torch.Tensor) -
7575
losses = 2 * torch.cat(losses, dim=2)
7676

7777
return losses.mean()
78-
78+
7979
def configure_optimizers(self):
8080
"""Configure the optimizers using learning rate found with LR finder if used"""
8181
if self.lr is not None:
@@ -84,7 +84,7 @@ def configure_optimizers(self):
8484
return self._optimizer(self.model)
8585

8686
def _calculate_common_losses(
87-
self,
87+
self,
8888
y: torch.Tensor,
8989
y_hat: torch.Tensor,
9090
) -> dict[str, torch.Tensor]:
@@ -96,15 +96,15 @@ def _calculate_common_losses(
9696
losses["quantile_loss"] = self._calculate_quantile_loss(y_hat, y)
9797
y_hat = self.model._quantiles_to_prediction(y_hat)
9898

99-
losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})
99+
losses.update({"MSE": F.mse_loss(y_hat, y), "MAE": F.l1_loss(y_hat, y)})
100100

101101
return losses
102-
102+
103103
def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
104104
"""Run training step"""
105105
y_hat = self.model(batch)
106106

107-
y = batch[self.model._target_key][:, -self.model.forecast_len :]
107+
y = batch["generation"][:, -self.model.forecast_len :]
108108

109109
losses = self._calculate_common_losses(y, y_hat)
110110
losses = {f"{k}/train": v for k, v in losses.items()}
@@ -116,10 +116,10 @@ def training_step(self, batch: TensorBatch, batch_idx: int) -> torch.Tensor:
116116
else:
117117
opt_target = losses["MAE/train"]
118118
return opt_target
119-
119+
120120
def _calculate_val_losses(
121-
self,
122-
y: torch.Tensor,
121+
self,
122+
y: torch.Tensor,
123123
y_hat: torch.Tensor,
124124
) -> dict[str, torch.Tensor]:
125125
"""Calculate additional losses only run in validation"""
@@ -138,28 +138,25 @@ def _calculate_val_losses(
138138
return losses
139139

140140
def _calculate_step_metrics(
141-
self,
142-
y: torch.Tensor,
143-
y_hat: torch.Tensor,
141+
self,
142+
y: torch.Tensor,
143+
y_hat: torch.Tensor,
144144
) -> tuple[np.array, np.array]:
145145
"""Calculate the MAE and MSE at each forecast step"""
146146

147147
mae_each_step = torch.mean(torch.abs(y_hat - y), dim=0).cpu().numpy()
148148
mse_each_step = torch.mean((y_hat - y) ** 2, dim=0).cpu().numpy()
149-
149+
150150
return mae_each_step, mse_each_step
151-
151+
152152
def _store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> None:
153153
"""Internally store the validation predictions"""
154-
155-
target_key = self.model._target_key
156154

157-
y = batch[target_key][:, -self.model.forecast_len :].cpu().numpy()
158-
y_hat = y_hat.cpu().numpy()
155+
y = batch["generation"][:, -self.model.forecast_len :].cpu().numpy()
156+
y_hat = y_hat.cpu().numpy()
159157
ids = batch["location_id"].cpu().numpy()
160158
init_times_utc = pd.to_datetime(
161-
batch["time_utc"][:, self.model.history_len+1]
162-
.cpu().numpy().astype("datetime64[ns]")
159+
batch["time_utc"][:, self.model.history_len + 1].cpu().numpy().astype("datetime64[ns]")
163160
)
164161

165162
if self.model.use_quantile_regression:
@@ -170,7 +167,7 @@ def _store_val_predictions(self, batch: TensorBatch, y_hat: torch.Tensor) -> Non
170167

171168
ds_preds_batch = xr.Dataset(
172169
data_vars=dict(
173-
y_hat=(["sample_num", "forecast_step", "p_level"], y_hat),
170+
y_hat=(["sample_num", "forecast_step", "p_level"], y_hat),
174171
y=(["sample_num", "forecast_step"], y),
175172
),
176173
coords=dict(
@@ -186,7 +183,7 @@ def on_validation_epoch_start(self):
186183
# Set up stores which we will fill during validation
187184
self.all_val_results: list[xr.Dataset] = []
188185
self._val_horizon_maes: list[np.array] = []
189-
if self.current_epoch==0:
186+
if self.current_epoch == 0:
190187
self._val_persistence_horizon_maes: list[np.array] = []
191188

192189
# Plot some sample forecasts
@@ -197,29 +194,26 @@ def on_validation_epoch_start(self):
197194

198195
for plot_num in range(num_figures):
199196
idxs = np.arange(plots_per_figure) + plot_num * plots_per_figure
200-
idxs = idxs[idxs<len(val_dataset)]
197+
idxs = idxs[idxs < len(val_dataset)]
201198

202-
if len(idxs)==0:
199+
if len(idxs) == 0:
203200
continue
204201

205202
batch = collate_fn([val_dataset[i] for i in idxs])
206203
batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
207204

208205
# Batch validation check only during sanity check phase - use first batch
209206
if self.trainer.sanity_checking and plot_num == 0:
210-
validate_batch_against_config(
211-
batch=batch,
212-
model=self.model
213-
)
214-
207+
validate_batch_against_config(batch=batch, model=self.model)
208+
215209
with torch.no_grad():
216210
y_hat = self.model(batch)
217-
211+
218212
fig = plot_sample_forecasts(
219213
batch,
220214
y_hat,
221215
quantiles=self.model.output_quantiles,
222-
key_to_plot=self.model._target_key,
216+
key_to_plot="generation",
223217
)
224218

225219
plot_name = f"val_forecast_samples/sample_set_{plot_num}"
@@ -238,7 +232,7 @@ def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
238232
# Internally store the val predictions
239233
self._store_val_predictions(batch, y_hat)
240234

241-
y = batch[self.model._target_key][:, -self.model.forecast_len :]
235+
y = batch["generation"][:, -self.model.forecast_len :]
242236

243237
losses = self._calculate_common_losses(y, y_hat)
244238
losses = {f"{k}/val": v for k, v in losses.items()}
@@ -262,21 +256,22 @@ def validation_step(self, batch: TensorBatch, batch_idx: int) -> None:
262256

263257
# Calculate the persistance losses - we only need to do this once per training run
264258
# not every epoch
265-
if self.current_epoch==0:
259+
if self.current_epoch == 0:
266260
y_persist = (
267-
batch[self.model._target_key][:, -(self.model.forecast_len+1)]
268-
.unsqueeze(1).expand(-1, self.model.forecast_len)
261+
batch["generation"][:, -(self.model.forecast_len + 1)]
262+
.unsqueeze(1)
263+
.expand(-1, self.model.forecast_len)
269264
)
270265
mae_step_persist, mse_step_persist = self._calculate_step_metrics(y, y_persist)
271266
self._val_persistence_horizon_maes.append(mae_step_persist)
272267
losses.update(
273268
{
274-
"MAE/val_persistence": mae_step_persist.mean(),
275-
"MSE/val_persistence": mse_step_persist.mean()
269+
"MAE/val_persistence": mae_step_persist.mean(),
270+
"MSE/val_persistence": mse_step_persist.mean(),
276271
}
277272
)
278273

279-
# Log the metrics
274+
# Log the metrics
280275
self.log_dict(losses, on_step=False, on_epoch=True)
281276

282277
def on_validation_epoch_end(self) -> None:
@@ -289,7 +284,7 @@ def on_validation_epoch_end(self) -> None:
289284
self._val_horizon_maes = []
290285

291286
# We only run this on the first epoch
292-
if self.current_epoch==0:
287+
if self.current_epoch == 0:
293288
val_persistence_horizon_maes = np.mean(self._val_persistence_horizon_maes, axis=0)
294289
self._val_persistence_horizon_maes = []
295290

@@ -321,25 +316,25 @@ def on_validation_epoch_end(self) -> None:
321316
wandb_log_dir = self.logger.experiment.dir
322317
filepath = f"{wandb_log_dir}/validation_results.netcdf"
323318
ds_val_results.to_netcdf(filepath)
324-
325-
# Uplodad to wandb
319+
320+
# Uplodad to wandb
326321
self.logger.experiment.save(filepath, base_path=wandb_log_dir, policy="now")
327-
322+
328323
# Create the horizon accuracy curve
329324
horizon_mae_plot = wandb_line_plot(
330-
x=np.arange(self.model.forecast_len),
325+
x=np.arange(self.model.forecast_len),
331326
y=val_horizon_maes,
332327
xlabel="Horizon step",
333328
ylabel="MAE",
334329
title="Val horizon loss curve",
335330
)
336-
331+
337332
wandb.log({"val_horizon_mae_plot": horizon_mae_plot})
338333

339334
# Create persistence horizon accuracy curve but only on first epoch
340-
if self.current_epoch==0:
335+
if self.current_epoch == 0:
341336
persist_horizon_mae_plot = wandb_line_plot(
342-
x=np.arange(self.model.forecast_len),
337+
x=np.arange(self.model.forecast_len),
343338
y=val_persistence_horizon_maes,
344339
xlabel="Horizon step",
345340
ylabel="MAE",

pvnet/utils.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Utils"""
2+
23
import logging
34
from typing import TYPE_CHECKING
45

@@ -17,7 +18,7 @@
1718
MODEL_CONFIG_NAME = "model_config.yaml"
1819
DATA_CONFIG_NAME = "data_config.yaml"
1920
DATAMODULE_CONFIG_NAME = "datamodule_config.yaml"
20-
FULL_CONFIG_NAME = "full_experiment_config.yaml"
21+
FULL_CONFIG_NAME = "full_experiment_config.yaml"
2122
MODEL_CARD_NAME = "README.md"
2223

2324

@@ -93,37 +94,41 @@ def print_config(
9394

9495

9596
def validate_batch_against_config(
96-
batch: dict,
97+
batch: dict,
9798
model: "BaseModel",
9899
) -> None:
99100
"""Validates tensor shapes in batch against model configuration."""
100101
logger.info("Performing batch shape validation against model config.")
101-
102+
102103
# NWP validation
103-
if hasattr(model, 'nwp_encoders_dict'):
104+
if hasattr(model, "nwp_encoders_dict"):
104105
if "nwp" not in batch:
105106
raise ValueError(
106107
"Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
107108
)
108-
109+
109110
for source, nwp_data in batch["nwp"].items():
110111
if source in model.nwp_encoders_dict:
111-
112-
enc = model.nwp_encoders_dict[source]
112+
enc = model.nwp_encoders_dict[source]
113113
expected_channels = enc.in_channels
114114
if model.add_image_embedding_channel:
115115
expected_channels -= 1
116116

117-
expected = (nwp_data["nwp"].shape[0], enc.sequence_length,
118-
expected_channels, enc.image_size_pixels, enc.image_size_pixels)
117+
expected = (
118+
nwp_data["nwp"].shape[0],
119+
enc.sequence_length,
120+
expected_channels,
121+
enc.image_size_pixels,
122+
enc.image_size_pixels,
123+
)
119124
if tuple(nwp_data["nwp"].shape) != expected:
120-
actual_shape = tuple(nwp_data['nwp'].shape)
125+
actual_shape = tuple(nwp_data["nwp"].shape)
121126
raise ValueError(
122127
f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
123128
)
124129

125130
# Satellite validation
126-
if hasattr(model, 'sat_encoder'):
131+
if hasattr(model, "sat_encoder"):
127132
if "satellite_actual" not in batch:
128133
raise ValueError(
129134
"Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
@@ -134,14 +139,19 @@ def validate_batch_against_config(
134139
if model.add_image_embedding_channel:
135140
expected_channels -= 1
136141

137-
expected = (batch["satellite_actual"].shape[0], enc.sequence_length, expected_channels,
138-
enc.image_size_pixels, enc.image_size_pixels)
142+
expected = (
143+
batch["satellite_actual"].shape[0],
144+
enc.sequence_length,
145+
expected_channels,
146+
enc.image_size_pixels,
147+
enc.image_size_pixels,
148+
)
139149
if tuple(batch["satellite_actual"].shape) != expected:
140-
actual_shape = tuple(batch['satellite_actual'].shape)
150+
actual_shape = tuple(batch["satellite_actual"].shape)
141151
raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
142152

143153
# generation validation
144-
key = model._target_key
154+
key = "generation"
145155
if key in batch:
146156
total_minutes = model.history_minutes + model.forecast_minutes
147157
interval = model.interval_minutes

0 commit comments

Comments
 (0)