diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py
index 06176145..76ec276b 100644
--- a/pvnet/models/multimodal/site_encoders/encoders.py
+++ b/pvnet/models/multimodal/site_encoders/encoders.py
@@ -4,7 +4,9 @@
 
 import einops
 import torch
-from ocf_datapipes.batch import BatchKey
+
+# removed BatchKey import
+# from ocf_datapipes.batch import BatchKey
 from torch import nn
 
 from pvnet.models.multimodal.linear_networks.networks import ResFCNet2
@@ -75,13 +77,13 @@ def __init__(
         )
 
     def _calculate_attention(self, x):
-        gsp_ids = x[BatchKey.gsp_id].squeeze().int()
+        gsp_ids = x["gsp_id"].squeeze().int()  # removed BatchKey
        attention = self._attention_network(gsp_ids)
        return attention
 
     def _encode_value(self, x):
         # Shape: [batch size, sequence length, PV site]
-        pv_site_seqs = x[BatchKey.pv].float()
+        pv_site_seqs = x["pv"].float()  # removed BatchKey
         batch_size = pv_site_seqs.shape[0]
 
         pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1)
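
For reference, a minimal, illustrative sketch (not part of this diff) of the string-keyed batch access the updated methods now expect in place of the `BatchKey` enum. The tensor shapes and key values below are assumptions chosen purely for demonstration:

```python
import torch

batch_size, seq_len, n_sites = 2, 12, 10

# Hypothetical batch: a plain dict keyed by strings rather than BatchKey members
batch = {
    "gsp_id": torch.randint(1, 318, (batch_size, 1)),
    "pv": torch.rand(batch_size, seq_len, n_sites),
}

# Same access pattern as the updated encoder methods
gsp_ids = batch["gsp_id"].squeeze().int()  # -> shape [batch_size]
pv_site_seqs = batch["pv"].float()  # -> shape [batch_size, seq_len, n_sites]
pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1)  # -> [batch_size * n_sites, seq_len]
```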