Skip to content

Commit 595c1df

Browse files
committed
Change inputs
1 parent f93f39c commit 595c1df

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

pvnet/models/multimodal/multimodal.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,10 @@ def forward(self, x):
260260
modes["pv"] = self.pv_encoder(x)
261261
else:
262262
# Target is PV, so only take the history
263-
x[BatchKey.pv] = x[BatchKey.pv][:, : self.history_len_30]
264-
modes["pv"] = self.pv_encoder(x)
263+
# Copy batch
264+
x_tmp = x.copy()
265+
x_tmp[BatchKey.pv] = x_tmp[BatchKey.pv][:, : self.history_len_30]
266+
modes["pv"] = self.pv_encoder(x_tmp)
265267

266268
# *********************** GSP Data ************************************
267269
# add gsp yield history
@@ -283,9 +285,10 @@ def forward(self, x):
283285
modes["wind"] = self.wind_encoder(x)
284286
else:
285287
# Have to be its own Batch format
286-
x[BatchKey.wind] = x[BatchKey.wind][:, : self.history_len_30]
288+
x_tmp = x.copy()
289+
x_tmp[BatchKey.wind] = x_tmp[BatchKey.wind][:, : self.history_len_30]
287290
# This needs to be a Batch as input
288-
modes["wind"] = self.wind_encoder(x)
291+
modes["wind"] = self.wind_encoder(x_tmp)
289292

290293
if self.include_sun:
291294
sun = torch.cat(

0 commit comments

Comments
 (0)