Skip to content

Commit 87d5718

Browse files
committed
tidy
1 parent 292ad9d commit 87d5718

File tree

3 files changed

+5
-23
lines changed

3 files changed

+5
-23
lines changed

pvnet/data/datamodule.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,6 @@
1212
)
1313

1414

15-
def fill_nans_in_arrays(batch):
16-
"""Fills all NaN values in each np.ndarray in the batch dictionary with zeros.
17-
18-
Operation is performed in-place on the batch.
19-
"""
20-
for k, v in batch.items():
21-
if isinstance(v, torch.Tensor):
22-
if torch.isnan(v).any():
23-
batch[k] = torch.nan_to_num(v, nan=0.0)
24-
25-
# Recursion is included to reach NWP arrays in subdict
26-
elif isinstance(v, dict):
27-
fill_nans_in_arrays(v)
28-
29-
return batch
30-
31-
32-
3315
class NumpybatchPremadeSamplesDataset(Dataset):
3416
"""Dataset to load NumpyBatch samples"""
3517

@@ -46,7 +28,7 @@ def __len__(self):
4628
return len(self.sample_paths)
4729

4830
def __getitem__(self, idx):
49-
return fill_nans_in_arrays(torch.load(self.sample_paths[idx]))
31+
return torch.load(self.sample_paths[idx])
5032

5133

5234
def collate_fn(samples: list[NumpyBatch]):

pvnet/models/multimodal/multimodal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def forward(self, x):
318318
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
319319

320320
if self.add_image_embedding_channel:
321-
id = x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int()
321+
id = x[BatchKey[f"{self._target_key_name}_id"]].int()
322322
sat_data = self.sat_embed(sat_data, id)
323323
modes["sat"] = self.sat_encoder(sat_data)
324324

@@ -335,7 +335,7 @@ def forward(self, x):
335335
nwp_data = torch.clip(nwp_data, min=-50, max=50)
336336

337337
if self.add_image_embedding_channel:
338-
id = x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int()
338+
id = x[BatchKey[f"{self._target_key_name}_id"]].int()
339339
nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id)
340340

341341
nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
@@ -362,7 +362,7 @@ def forward(self, x):
362362

363363
# ********************** Embedding of GSP ID ********************
364364
if self.embedding_dim:
365-
id = x[BatchKey[f"{self._target_key_name}_id"]][:].int()
365+
id = x[BatchKey[f"{self._target_key_name}_id"]].int()
366366
id_embedding = self.embed(id)
367367
modes["id"] = id_embedding
368368

pvnet/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def _get_numpy(key):
265265
y_id_key = BatchKey[f"{key_to_plot}_id"]
266266
BatchKey[f"{key_to_plot}_t0_idx"]
267267
time_utc_key = BatchKey[f"{key_to_plot}_time_utc"]
268-
y = batch[y_key][:, :].cpu().numpy() # Select the one it is trained on
268+
y = batch[y_key].cpu().numpy() # Select the one it is trained on
269269
y_hat = y_hat.cpu().numpy()
270270
# Select between the timesteps in timesteps to plot
271271
plotting_name = key_to_plot.upper()

0 commit comments

Comments
 (0)