Skip to content

Commit c14375b

Browse files
committed
Merge remote-tracking branch 'origin/jacob/windnet' into jacob/windnet
# Conflicts: # pvnet/models/base_model.py
2 parents f28d77d + ea77b0e commit c14375b

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

pvnet/models/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,12 @@ def __bool__(self):
9999

100100
# @staticmethod
101101
def _filter_batch_dict(self, d):
102-
keep_keys = (
103-
[BatchKey[self.key_to_keep], BatchKey[f"{self.key_to_keep}_id"], BatchKey[f"{self.key_to_keep}_t0_idx"], BatchKey[f"{self.key_to_keep}_time_utc"]]
104-
)
102+
keep_keys = [
103+
BatchKey[self.key_to_keep],
104+
BatchKey[f"{self.key_to_keep}_id"],
105+
BatchKey[f"{self.key_to_keep}_t0_idx"],
106+
BatchKey[f"{self.key_to_keep}_time_utc"],
107+
]
105108
return {k: v for k, v in d.items() if k in keep_keys}
106109

107110
def append(self, batch: dict[BatchKey, list[torch.Tensor]]):

pvnet/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _get_numpy(key):
255255
y = batch[y_key][:, :, 0].cpu().numpy() # Select the one it is trained on
256256
y_hat = y_hat.cpu().numpy()
257257
gsp_ids = batch[y_id_key][:, 0].cpu().numpy().squeeze()
258-
t0_idx = int(batch[t0_idx_key])
258+
int(batch[t0_idx_key])
259259
plotting_name = key_to_plot.upper()
260260

261261
gsp_ids = batch[y_id_key].cpu().numpy().squeeze()

0 commit comments

Comments
 (0)