Skip to content

Commit

Permalink
Fix encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jan 24, 2024
1 parent 0d8d349 commit 08a72e5
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ def forward(self, x):
modes["pv"] = self.pv_encoder(x)
else:
# Target is PV, so only take the history
pv_history = x[BatchKey.pv][:, : self.history_len_30].float()
modes["pv"] = self.pv_encoder(pv_history)
x[BatchKey.pv] = x[BatchKey.pv][:, : self.history_len_30]
modes["pv"] = self.pv_encoder(x)

# *********************** GSP Data ************************************
# add gsp yield history
Expand All @@ -282,9 +282,10 @@ def forward(self, x):
if self.target_key_name != "wind":
modes["wind"] = self.wind_encoder(x)
else:
# Target is wind, so only take the history
wind_history = x[BatchKey.wind][:, : self.history_len_30].float()
modes["wind"] = self.wind_encoder(wind_history)
# Have to be its own Batch format
x[BatchKey.wind] = x[BatchKey.wind][:, : self.history_len_30]
# This needs to be a Batch as input
modes["wind"] = self.wind_encoder(x)

if self.include_sun:
sun = torch.cat(
Expand Down

0 comments on commit 08a72e5

Please sign in to comment.