Skip to content

Commit 9cfc8a9

Browse files
committed
Make target key more general
1 parent 9b91283 commit 9cfc8a9

File tree

3 files changed

+14
-20
lines changed

3 files changed

+14
-20
lines changed

pvnet/models/base_model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
forecast_minutes: int,
240240
optimizer: AbstractOptimizer,
241241
output_quantiles: Optional[list[float]] = None,
242-
target_key: BatchKey = BatchKey.gsp,
242+
target_key: str = "gsp",
243243
):
244244
"""Abtstract base class for PVNet submodels.
245245
@@ -254,7 +254,8 @@ def __init__(
254254
super().__init__()
255255

256256
self._optimizer = optimizer
257-
self._target_key = target_key
257+
self._target_key_name = target_key
258+
self._target_key = BatchKey[f"{self._target_key}"]
258259

259260
# Model must have lr to allow tuning
260261
# This setting is only used when lr is tuned with callback
@@ -275,7 +276,7 @@ def __init__(
275276

276277
self._accumulated_metrics = MetricAccumulator()
277278
self._accumulated_batches = BatchAccumulator(
278-
key_to_keep="gsp" if self._target_key == BatchKey.gsp else "sensor"
279+
key_to_keep=self._target_key
279280
)
280281
self._accumulated_y_hat = PredAccumulator()
281282

@@ -438,7 +439,7 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
438439
y_hat,
439440
batch_idx,
440441
quantiles=self.output_quantiles,
441-
key_to_plot="gsp" if self._target_key == BatchKey.gsp else "wind",
442+
key_to_plot=self._target_key_name,
442443
)
443444
fig.savefig("latest_logged_train_batch.png")
444445

@@ -503,7 +504,7 @@ def validation_step(self, batch: dict, batch_idx):
503504
if not hasattr(self, "_val_y_hats"):
504505
self._val_y_hats = PredAccumulator()
505506
self._val_batches = BatchAccumulator(
506-
key_to_keep="gsp" if self._target_key == BatchKey.gsp else "sensor"
507+
key_to_keep=self._target_key_name
507508
)
508509

509510
self._val_y_hats.append(y_hat)
@@ -517,7 +518,7 @@ def validation_step(self, batch: dict, batch_idx):
517518
batch,
518519
y_hat,
519520
quantiles=self.output_quantiles,
520-
key_to_plot="gsp" if self._target_key == BatchKey.gsp else "sensor",
521+
key_to_plot=self._target_key_name,
521522
)
522523

523524
self.logger.experiment.log(

pvnet/models/utils.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,7 @@ def __bool__(self):
100100
# @staticmethod
101101
def _filter_batch_dict(self, d):
102102
keep_keys = (
103-
[BatchKey.gsp, BatchKey.gsp_id, BatchKey.gsp_t0_idx, BatchKey.gsp_time_utc]
104-
if self.key_to_keep == "gsp"
105-
else [
106-
BatchKey.wind,
107-
BatchKey.wind_id,
108-
BatchKey.wind_t0_idx,
109-
BatchKey.wind_time_utc,
110-
]
103+
[BatchKey[self.key_to_keep], BatchKey[f"{self.key_to_keep}_id"], BatchKey[f"{self.key_to_keep}_t0_idx"], BatchKey[f"{self.key_to_keep}_time_utc"]]
111104
)
112105
return {k: v for k, v in d.items() if k in keep_keys}
113106

@@ -122,7 +115,7 @@ def flush(self) -> dict[BatchKey, list[torch.Tensor]]:
122115
"""Concatenate all accumulated batches, return, and clear self"""
123116
batch = {}
124117
for k, v in self._batches.items():
125-
if k == BatchKey.gsp_t0_idx or k == BatchKey.wind_t0_idx:
118+
if k == BatchKey.gsp_t0_idx or k == BatchKey.wind_t0_idx or k == BatchKey.pv_t0_idx:
126119
batch[k] = v[0]
127120
else:
128121
batch[k] = torch.cat(v, dim=0)

pvnet/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -248,15 +248,15 @@ def plot_batch_forecasts(batch, y_hat, batch_idx=None, quantiles=None, key_to_pl
248248
def _get_numpy(key):
249249
return batch[key].cpu().numpy().squeeze()
250250

251-
y_key = BatchKey.gsp if key_to_plot == "gsp" else BatchKey.wind
252-
y_id_key = BatchKey.gsp_id if key_to_plot == "gsp" else BatchKey.wind_id
253-
t0_idx_key = BatchKey.gsp_t0_idx if key_to_plot == "gsp" else BatchKey.wind_t0_idx
254-
time_utc_key = BatchKey.gsp_time_utc if key_to_plot == "gsp" else BatchKey.wind_time_utc
251+
y_key = BatchKey[f"{key_to_plot}"]
252+
y_id_key = BatchKey[f"{key_to_plot}_id"]
253+
t0_idx_key = BatchKey[f"{key_to_plot}_t0_idx"]
254+
time_utc_key = BatchKey[f"{key_to_plot}_time_utc"]
255255
y = batch[y_key][:, :, 0].cpu().numpy() # Select the one it is trained on
256256
y_hat = y_hat.cpu().numpy()
257257
gsp_ids = batch[y_id_key][:, 0].cpu().numpy().squeeze()
258258
t0_idx = int(batch[t0_idx_key])
259-
plotting_name = "GSP" if key_to_plot == "gsp" else "Wind"
259+
plotting_name = key_to_plot.upper()
260260

261261
gsp_ids = batch[y_id_key].cpu().numpy().squeeze()
262262
batch[t0_idx_key]

0 commit comments

Comments
 (0)