We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ae1d46b commit eb2fbe5Copy full SHA for eb2fbe5
pvnet/models/multimodal/multimodal.py
@@ -342,9 +342,16 @@ def forward(self, x):
342
modes["sensor"] = self.sensor_encoder(x_tmp)
343
344
if self.include_sun:
345
- sun = torch.cat(
346
- (x[BatchKey.gsp_solar_azimuth], x[BatchKey.gsp_solar_elevation]), dim=1
347
- ).float()
+ if self.target_key_name == "gsp":
+ sun = torch.cat(
+ (x[BatchKey.gsp_solar_azimuth], x[BatchKey.gsp_solar_elevation]), dim=1
348
+ ).float()
349
+ elif self.target_key_name == "pv":
350
351
+ (x[BatchKey.pv_solar_azimuth], x[BatchKey.pv_solar_elevation]), dim=1
352
353
+ else:
354
+ raise ValueError(f"Unknown target key for sun elevation {self.target_key_name}")
355
sun = self.sun_fc1(sun)
356
modes["sun"] = sun
357
0 commit comments