Skip to content

Commit b283ac4

Browse files
committed
Simplify
1 parent eb2fbe5 commit b283ac4

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

pvnet/models/multimodal/multimodal.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -342,16 +342,9 @@ def forward(self, x):
342342
modes["sensor"] = self.sensor_encoder(x_tmp)
343343

344344
if self.include_sun:
345-
if self.target_key_name == "gsp":
346-
sun = torch.cat(
347-
(x[BatchKey.gsp_solar_azimuth], x[BatchKey.gsp_solar_elevation]), dim=1
348-
).float()
349-
elif self.target_key_name == "pv":
350-
sun = torch.cat(
351-
(x[BatchKey.pv_solar_azimuth], x[BatchKey.pv_solar_elevation]), dim=1
352-
).float()
353-
else:
354-
raise ValueError(f"Unknown target key for sun elevation {self.target_key_name}")
345+
sun = torch.cat(
346+
(x[BatchKey[f"{self.target_key_name}_solar_azimuth"]], x[BatchKey[f"{self.target_key_name}_solar_elevation"]]), dim=1
347+
).float()
355348
sun = self.sun_fc1(sun)
356349
modes["sun"] = sun
357350

0 commit comments

Comments
 (0)