File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
pvnet/models/multimodal/site_encoders Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change 44
55import einops
66import torch
7+ from ocf_datapipes .batch import BatchKey
78from torch import nn
89
910from pvnet .models .multimodal .linear_networks .networks import ResFCNet2
@@ -74,13 +75,13 @@ def __init__(
7475 )
7576
7677 def _calculate_attention (self , x ):
77- gsp_ids = x [" gsp_id" ].squeeze ().int ()
78+ gsp_ids = x [BatchKey . gsp_id ].squeeze ().int ()
7879 attention = self ._attention_network (gsp_ids )
7980 return attention
8081
8182 def _encode_value (self , x ):
8283 # Shape: [batch size, sequence length, PV site]
83- pv_site_seqs = x ["pv" ].float ()
84+ pv_site_seqs = x [BatchKey . pv ].float ()
8485 batch_size = pv_site_seqs .shape [0 ]
8586
8687 pv_site_seqs = pv_site_seqs .swapaxes (1 , 2 ).flatten (0 , 1 )
You can’t perform that action at this time.
0 commit comments