Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/jacob/windnet' into jacob/windnet
Browse files Browse the repository at this point in the history
# Conflicts:
#	pvnet/models/base_model.py
  • Loading branch information
jacobbieker committed Jan 19, 2024
2 parents f28d77d + ea77b0e commit c14375b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions pvnet/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,12 @@ def __bool__(self):

# @staticmethod
def _filter_batch_dict(self, d):
keep_keys = (
[BatchKey[self.key_to_keep], BatchKey[f"{self.key_to_keep}_id"], BatchKey[f"{self.key_to_keep}_t0_idx"], BatchKey[f"{self.key_to_keep}_time_utc"]]
)
keep_keys = [
BatchKey[self.key_to_keep],
BatchKey[f"{self.key_to_keep}_id"],
BatchKey[f"{self.key_to_keep}_t0_idx"],
BatchKey[f"{self.key_to_keep}_time_utc"],
]
return {k: v for k, v in d.items() if k in keep_keys}

def append(self, batch: dict[BatchKey, list[torch.Tensor]]):
Expand Down
2 changes: 1 addition & 1 deletion pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _get_numpy(key):
y = batch[y_key][:, :, 0].cpu().numpy() # Select the one it is trained on
y_hat = y_hat.cpu().numpy()
gsp_ids = batch[y_id_key][:, 0].cpu().numpy().squeeze()
t0_idx = int(batch[t0_idx_key])
int(batch[t0_idx_key])
plotting_name = key_to_plot.upper()

gsp_ids = batch[y_id_key].cpu().numpy().squeeze()
Expand Down

0 comments on commit c14375b

Please sign in to comment.