From 9f16d1e5a9e76692065d7d292199002970e7b48d Mon Sep 17 00:00:00 2001 From: devojyotimisra Date: Sat, 1 Mar 2025 11:26:44 +0530 Subject: [PATCH 1/2] removed 'BatchKey' from 'encoders.py' and replaced necessary parts with strings --- pvnet/models/multimodal/site_encoders/encoders.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index 06176145..7ed5d188 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -4,7 +4,10 @@ import einops import torch -from ocf_datapipes.batch import BatchKey + +# removed BatchKey import +# from ocf_datapipes.batch import BatchKey + from torch import nn from pvnet.models.multimodal.linear_networks.networks import ResFCNet2 @@ -75,13 +78,13 @@ def __init__( ) def _calculate_attention(self, x): - gsp_ids = x[BatchKey.gsp_id].squeeze().int() + gsp_ids = x["gsp_id"].squeeze().int() # removed BatchKey attention = self._attention_network(gsp_ids) return attention def _encode_value(self, x): # Shape: [batch size, sequence length, PV site] - pv_site_seqs = x[BatchKey.pv].float() + pv_site_seqs = x["pv"].float() # removed BatchKey batch_size = pv_site_seqs.shape[0] pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1) From 19fa7686c9b381c7caf11e4b57fddd679fc864d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 1 Mar 2025 06:11:00 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/models/multimodal/site_encoders/encoders.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index 7ed5d188..76ec276b 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -7,7 +7,6 @@ # removed BatchKey import # from ocf_datapipes.batch import BatchKey - from torch import nn from pvnet.models.multimodal.linear_networks.networks import ResFCNet2 @@ -78,13 +77,13 @@ def __init__( ) def _calculate_attention(self, x): - gsp_ids = x["gsp_id"].squeeze().int() # removed BatchKey + gsp_ids = x["gsp_id"].squeeze().int() # removed BatchKey attention = self._attention_network(gsp_ids) return attention def _encode_value(self, x): # Shape: [batch size, sequence length, PV site] - pv_site_seqs = x["pv"].float() # removed BatchKey + pv_site_seqs = x["pv"].float() # removed BatchKey batch_size = pv_site_seqs.shape[0] pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1)