Skip to content

Commit 63f3fc0

Browse files
committed
Clean up
1 parent 3f07b32 commit 63f3fc0

File tree

7 files changed

+10
-55
lines changed

7 files changed

+10
-55
lines changed

pvnet/models/multimodal/site_encoders/encoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def __init__(
206206
)
207207

208208
def _encode_inputs(self, x):
209-
# Shape: [batch size, sequence length, number of sites] -> [8, 197, 1]
210-
# Shape: [batch size, station_id, sequence length, channels] -> [8, 197, 26, 23]
209+
# Shape: [batch size, sequence length, number of sites]
210+
# Shape: [batch size, station_id, sequence length, channels]
211211
input_data = x[f"{self.input_key_to_use}"]
212212
if len(input_data.shape) == 2: # one site per sample
213213
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D

pvnet/models/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import torch
7-
from ocf_datapipes.batch import BatchKey
87

98
logger = logging.getLogger(__name__)
109

@@ -84,7 +83,7 @@ class BatchAccumulator(DictListAccumulator):
8483
"""A class for accumulating batches when using grad accumulation and the batch size is small.
8584
8685
Attributes:
87-
_batches (Dict[BatchKey, list[torch.Tensor]]): Dictionary containing lists of metrics.
86+
_batches (Dict[str, list[torch.Tensor]]): Dictionary containing lists of metrics.
8887
"""
8988

9089
def __init__(self, key_to_keep: str = "gsp"):
@@ -105,14 +104,14 @@ def _filter_batch_dict(self, d):
105104
]
106105
return {k: v for k, v in d.items() if k in keep_keys}
107106

108-
def append(self, batch: dict[BatchKey, list[torch.Tensor]]):
107+
def append(self, batch: dict[str, list[torch.Tensor]]):
109108
"""Append batch to self"""
110109
if not self:
111110
self._batches = self._dict_init_list(self._filter_batch_dict(batch))
112111
else:
113112
self._dict_list_append(self._batches, self._filter_batch_dict(batch))
114113

115-
def flush(self) -> dict[BatchKey, list[torch.Tensor]]:
114+
def flush(self) -> dict[str, list[torch.Tensor]]:
116115
"""Concatenate all accumulated batches, return, and clear self"""
117116
batch = {}
118117
for k, v in self._batches.items():

pvnet/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def plot_batch_forecasts(
258258
def _get_numpy(key):
259259
return batch[key].cpu().numpy().squeeze()
260260

261-
y_key = f"{key_to_plot}"
261+
y_key = key_to_plot
262262
y_id_key = f"{key_to_plot}_id"
263263
time_utc_key = f"{key_to_plot}_time_utc"
264264
y = batch[y_key].cpu().numpy() # Select the one it is trained on

tests/conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import torch
1010
import hydra
1111

12-
from ocf_datapipes.batch import BatchKey
1312
from datetime import timedelta
1413

1514
import pvnet
@@ -164,7 +163,7 @@ def sample_pv_batch():
164163
def sample_site_batch():
165164
dm = SiteDataModule(
166165
configuration=None,
167-
batch_size=8,
166+
batch_size=2,
168167
num_workers=0,
169168
prefetch_factor=None,
170169
train_period=[None, None],

tests/models/multimodal/site_encoders/test_encoders.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import torch
2-
from ocf_datapipes.batch import BatchKey
32
from torch import nn
43

54
from pvnet.models.multimodal.site_encoders.encoders import (
@@ -38,11 +37,11 @@ def test_singleattentionnetwork_forward(sample_site_batch, site_encoder_model_kw
3837
sample_site_batch,
3938
SingleAttentionNetwork,
4039
site_encoder_model_kwargs_dsampler,
41-
batch_size=8,
40+
batch_size=2,
4241
)
4342

4443

45-
# TODO once we have updated the sample batches for sites include this test
44+
# TODO once we have test data which inludes sensor data with sites include this test
4645
# def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
4746
# _test_model_forward(
4847
# sample_wind_batch,

tests/models/multimodal/test_multimodal.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from torch.optim import SGD
22
import pytest
3-
from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey
4-
53

64
def test_model_forward(multimodal_model, sample_batch):
75
y = multimodal_model(sample_batch)

tests/test_data/presaved_samples_site/data_configuration.yaml

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,6 @@ input_data:
1212
dropout_timedeltas_minutes: null
1313
dropout_fraction: 0 # Fraction of samples with dropout
1414

15-
# nwp:
16-
# mo_global:
17-
# provider: mo_global
18-
# zarr_path: /mnt/storage_b/nwp/ceda/global/20*.zarr
19-
# interval_start_minutes: -60
20-
# interval_end_minutes: 1860
21-
# time_resolution_minutes: 60
22-
# channels:
23-
# - temperature_sl
24-
# - wind_u_component_10m
25-
# - wind_v_component_10m
26-
# image_size_pixels_height: 4
27-
# image_size_pixels_width: 4
28-
# # A random value from the list below will be chosen as the delay when dropout is used
29-
# # If set to null no dropout is applied. Values must be negative.
30-
# dropout_timedeltas_minutes: [-180]
31-
# # Dropout applied with this probability
32-
# dropout_fraction: 1.0
33-
# # How long after the NWP init-time are we still willing to use this forecast
34-
# # If null we use each init-time for all steps it covers
35-
# max_staleness_minutes: null
36-
3715
nwp:
3816
ecmwf:
3917
provider: ecmwf
@@ -56,22 +34,4 @@ input_data:
5634
image_size_pixels_width: 24
5735
dropout_timedeltas_minutes: [-360]
5836
dropout_fraction: 1.0
59-
max_staleness_minutes: null
60-
61-
# gfs:
62-
# provider: gfs
63-
# dropout_fraction: 1.0
64-
# dropout_timedeltas_minutes: [-300]
65-
# interval_start_minutes: 0
66-
# interval_end_minutes: 2160
67-
# channels:
68-
# - t
69-
# - prate
70-
# - u10
71-
# - v10
72-
# - u100
73-
# - v100
74-
# image_size_pixels_height: 2
75-
# image_size_pixels_width: 2
76-
# zarr_path: /mnt/storage_b/nwp/gfs/nw_india/gfs_nw_india_*.zarr
77-
# time_resolution_minutes: 180
37+
max_staleness_minutes: null

0 commit comments

Comments
 (0)