Skip to content

Commit a0937ff

Browse files
removed BtachKey and made necessary keys as type string
1 parent e9837bb commit a0937ff

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

pvnet/models/multimodal/site_encoders/encoders.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import einops
66
import torch
7-
from ocf_datapipes.batch import BatchKey
87
from torch import nn
98

109
from pvnet.models.multimodal.linear_networks.networks import ResFCNet2
@@ -75,13 +74,15 @@ def __init__(
7574
)
7675

7776
def _calculate_attention(self, x):
78-
gsp_ids = x[BatchKey.gsp_id].squeeze().int()
77+
gsp_ids = x["BatchKey.gsp_id"].squeeze().int()
7978
attention = self._attention_network(gsp_ids)
8079
return attention
8180

8281
def _encode_value(self, x):
8382
# Shape: [batch size, sequence length, PV site]
84-
pv_site_seqs = x[BatchKey.pv].float()
83+
for key in x.keys():
84+
print(key)
85+
pv_site_seqs = x["BatchKey.pv"].float()
8586
batch_size = pv_site_seqs.shape[0]
8687

8788
pv_site_seqs = pv_site_seqs.swapaxes(1, 2).flatten(0, 1)

0 commit comments

Comments
 (0)