diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index c17cccfe..11ac5a1d 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -260,8 +260,8 @@ def forward(self, x): modes["pv"] = self.pv_encoder(x) else: # Target is PV, so only take the history - pv_history = x[BatchKey.pv][:, : self.history_len_30].float() - modes["pv"] = self.pv_encoder(pv_history) + x[BatchKey.pv] = x[BatchKey.pv][:, : self.history_len_30] + modes["pv"] = self.pv_encoder(x) # *********************** GSP Data ************************************ # add gsp yield history @@ -282,9 +282,10 @@ def forward(self, x): if self.target_key_name != "wind": modes["wind"] = self.wind_encoder(x) else: - # Target is wind, so only take the history - wind_history = x[BatchKey.wind][:, : self.history_len_30].float() - modes["wind"] = self.wind_encoder(wind_history) + # Have to be its own Batch format + x[BatchKey.wind] = x[BatchKey.wind][:, : self.history_len_30] + # This needs to be a Batch as input + modes["wind"] = self.wind_encoder(x) if self.include_sun: sun = torch.cat(