File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
pvnet/models/multimodal/site_encoders Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change 4
4
5
5
import einops
6
6
import torch
7
+ from ocf_datapipes .batch import BatchKey
7
8
from torch import nn
8
9
9
10
from pvnet .models .multimodal .linear_networks .networks import ResFCNet2
@@ -74,13 +75,13 @@ def __init__(
74
75
)
75
76
76
77
def _calculate_attention (self , x ):
77
- gsp_ids = x [" gsp_id" ].squeeze ().int ()
78
+ gsp_ids = x [BatchKey . gsp_id ].squeeze ().int ()
78
79
attention = self ._attention_network (gsp_ids )
79
80
return attention
80
81
81
82
def _encode_value (self , x ):
82
83
# Shape: [batch size, sequence length, PV site]
83
- pv_site_seqs = x ["pv" ].float ()
84
+ pv_site_seqs = x [BatchKey . pv ].float ()
84
85
batch_size = pv_site_seqs .shape [0 ]
85
86
86
87
pv_site_seqs = pv_site_seqs .swapaxes (1 , 2 ).flatten (0 , 1 )
You can’t perform that action at this time.
0 commit comments