Skip to content

Commit 4d69524

Browse files
committed
update tests
1 parent bfdf112 commit 4d69524

27 files changed

+228
-161
lines changed

pvnet/models/base_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def _log_validation_results(self, batch, y_hat, accum_batch_num):
605605
"""Append validation results to self.validation_epoch_results"""
606606

607607
# get truth values, shape (b, forecast_len)
608-
y = batch[self._target_key][:, -self.forecast_len :, 0]
608+
y = batch[self._target_key][:, -self.forecast_len :]
609609
y = y.detach().cpu().numpy()
610610
batch_size = y.shape[0]
611611

pvnet/models/baseline/last_value.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def forward(self, x: dict):
3636

3737
# take the last value non forecaster value and the first in the pv yield
3838
# (this is the pv site we are predicting for)
39-
y_hat = gsp_yield[:, -self.forecast_len - 1, 0]
39+
y_hat = gsp_yield[:, -self.forecast_len - 1]
4040

4141
# expand the last valid forward n predict steps
4242
out = y_hat.unsqueeze(1).repeat(1, self.forecast_len)

pvnet/models/baseline/single_value.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ def __init__(
3333
def forward(self, x: dict):
3434
"""Run model forward on dict batch of data"""
3535
# Returns a single value at all steps
36-
y_hat = torch.zeros_like(x[BatchKey.gsp][:, : self.forecast_len, 0]) + self._value
36+
y_hat = torch.zeros_like(x[BatchKey.gsp][:, : self.forecast_len]) + self._value
3737
return y_hat

pvnet/models/multimodal/multimodal.py

-10
Original file line numberDiff line numberDiff line change
@@ -377,16 +377,6 @@ def forward(self, x):
377377
# This needs to be a Batch as input
378378
modes["wind"] = self.wind_encoder(x_tmp)
379379

380-
# *********************** Sensor Data ************************************
381-
if self.include_sensor:
382-
if self._target_key_name != "sensor":
383-
modes["sensor"] = self.sensor_encoder(x)
384-
else:
385-
x_tmp = x.copy()
386-
x_tmp[BatchKey.sensor] = x_tmp[BatchKey.sensor][:, : self.history_len + 1]
387-
# This needs to be a Batch as input
388-
modes["sensor"] = self.sensor_encoder(x_tmp)
389-
390380
if self.include_sun:
391381
sun = torch.cat(
392382
(

pvnet/models/multimodal/unimodal_teacher.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def teacher_forward(self, x):
219219
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
220220

221221
if self.add_image_embedding_channel:
222-
id = x[BatchKey.gsp_id][:, 0].int()
222+
id = x[BatchKey.gsp_id].int()
223223
sat_data = teacher_model.sat_embed(sat_data, id)
224224

225225
modes[mode] = teacher_model.sat_encoder(sat_data)
@@ -233,7 +233,7 @@ def teacher_forward(self, x):
233233
nwp_data = torch.swapaxes(nwp_data, 1, 2) # switch time and channels
234234
nwp_data = torch.clip(nwp_data, min=-50, max=50)
235235
if teacher_model.add_image_embedding_channel:
236-
id = x[BatchKey.gsp_id][:, 0].int()
236+
id = x[BatchKey.gsp_id].int()
237237
nwp_data = teacher_model.nwp_embed_dict[nwp_source](nwp_data, id)
238238

239239
nwp_out = teacher_model.nwp_encoders_dict[nwp_source](nwp_data)
@@ -260,7 +260,7 @@ def forward(self, x, return_modes=False):
260260
sat_data = torch.swapaxes(sat_data, 1, 2).float() # switch time and channels
261261

262262
if self.add_image_embedding_channel:
263-
id = x[BatchKey.gsp_id][:, 0].int()
263+
id = x[BatchKey.gsp_id].int()
264264
sat_data = self.sat_embed(sat_data, id)
265265
modes["sat"] = self.sat_encoder(sat_data)
266266

@@ -276,7 +276,7 @@ def forward(self, x, return_modes=False):
276276
nwp_data = torch.clip(nwp_data, min=-50, max=50)
277277

278278
if self.add_image_embedding_channel:
279-
id = x[BatchKey.gsp_id][:, 0].int()
279+
id = x[BatchKey.gsp_id].int()
280280
nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id)
281281

282282
nwp_out = self.nwp_encoders_dict[nwp_source](nwp_data)
@@ -301,7 +301,7 @@ def forward(self, x, return_modes=False):
301301

302302
# ********************** Embedding of GSP ID ********************
303303
if self.embedding_dim:
304-
id = x[BatchKey.gsp_id][:, 0].int()
304+
id = x[BatchKey.gsp_id].int()
305305
id_embedding = self.embed(id)
306306
modes["id"] = id_embedding
307307

tests/conftest.py

+17-35
Original file line numberDiff line numberDiff line change
@@ -106,40 +106,37 @@ def sample_train_val_datamodule():
106106

107107
file_n = 0
108108

109-
for file in glob.glob("tests/test_data/sample_batches/train/*.pt"):
110-
batch = torch.load(file)
109+
for file_n, file in enumerate(glob.glob("tests/test_data/presaved_samples/train/*.pt")):
110+
sample = torch.load(file)
111111

112112
for i in range(n_duplicates):
113113
# Save for both train and val
114-
torch.save(batch, f"{tmpdirname}/train/{file_n:06}.pt")
115-
torch.save(batch, f"{tmpdirname}/val/{file_n:06}.pt")
116-
117-
file_n += 1
114+
torch.save(sample, f"{tmpdirname}/train/{file_n:06}.pt")
115+
torch.save(sample, f"{tmpdirname}/val/{file_n:06}.pt")
118116

119117
dm = DataModule(
120118
configuration=None,
119+
sample_dir=f"{tmpdirname}",
121120
batch_size=2,
122121
num_workers=0,
123122
prefetch_factor=None,
124123
train_period=[None, None],
125124
val_period=[None, None],
126-
test_period=[None, None],
127-
batch_dir=f"{tmpdirname}",
125+
128126
)
129127
yield dm
130128

131129

132130
@pytest.fixture()
133131
def sample_datamodule():
134132
dm = DataModule(
133+
sample_dir="tests/test_data/presaved_samples",
135134
configuration=None,
136135
batch_size=2,
137136
num_workers=0,
138137
prefetch_factor=None,
139138
train_period=[None, None],
140139
val_period=[None, None],
141-
test_period=[None, None],
142-
batch_dir="tests/test_data/sample_batches",
143140
)
144141
return dm
145142

@@ -157,9 +154,10 @@ def sample_satellite_batch(sample_batch):
157154

158155

159156
@pytest.fixture()
160-
def sample_pv_batch(sample_batch):
161-
pv_data = sample_batch[BatchKey.pv]
162-
return pv_data
157+
def sample_pv_batch():
158+
# TODO: Once PV site inputs are available from ocf-data-sampler UK regional remove these
159+
# old batches. For now we use the old batches to test the site encoder models
160+
return torch.load("tests/test_data/presaved_batches/train/000000.pt")
163161

164162

165163
@pytest.fixture()
@@ -191,7 +189,7 @@ def model_minutes_kwargs():
191189
def encoder_model_kwargs():
192190
# Used to test encoder model on satellite data
193191
kwargs = dict(
194-
sequence_length=(90 - 30) // 5 + 1,
192+
sequence_length=7, # 30 minutes of 5 minutely satellite data = 7 time steps
195193
image_size_pixels=24,
196194
in_channels=11,
197195
out_features=128,
@@ -240,23 +238,16 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs):
240238
"ukv": dict(
241239
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
242240
_partial_=True,
243-
in_channels=2,
241+
in_channels=11,
244242
out_features=128,
245243
number_of_conv3d_layers=6,
246244
conv3d_channels=32,
247245
image_size_pixels=24,
248246
),
249247
},
250248
add_image_embedding_channel=True,
251-
pv_encoder=dict(
252-
_target_=pvnet.models.multimodal.site_encoders.encoders.SingleAttentionNetwork,
253-
_partial_=True,
254-
num_sites=349,
255-
out_features=40,
256-
num_heads=4,
257-
kdim=40,
258-
id_embed_dim=20,
259-
),
249+
# ocf-data-sampler doesn't support PV site inputs yet
250+
pv_encoder=None,
260251
output_network=dict(
261252
_target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2,
262253
_partial_=True,
@@ -268,11 +259,10 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs):
268259
embedding_dim=16,
269260
include_sun=True,
270261
include_gsp_yield_history=True,
271-
sat_history_minutes=90,
262+
sat_history_minutes=30,
272263
nwp_history_minutes={"ukv": 120},
273264
nwp_forecast_minutes={"ukv": 480},
274-
pv_history_minutes=180,
275-
min_sat_delay_minutes=30,
265+
min_sat_delay_minutes=0,
276266
)
277267

278268
kwargs.update(model_minutes_kwargs)
@@ -297,14 +287,6 @@ def multimodal_quantile_model(multimodal_model_kwargs):
297287
return model
298288

299289

300-
@pytest.fixture()
301-
def multimodal_weighted_quantile_model(multimodal_model_kwargs):
302-
model = Model(
303-
output_quantiles=[0.1, 0.5, 0.9], **multimodal_model_kwargs, use_weighted_loss=True
304-
)
305-
return model
306-
307-
308290
@pytest.fixture()
309291
def multimodal_quantile_model_ignore_minutes(multimodal_model_kwargs):
310292
"""Only forecast second half of the 8 hours"""

tests/data/test_datamodule.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from pvnet.data.datamodule import DataModule
23
from pvnet.data.wind_datamodule import WindDataModule
34
from pvnet.data.pv_site_datamodule import PVSiteDataModule
@@ -8,16 +9,16 @@
89
def test_init():
910
dm = DataModule(
1011
configuration=None,
12+
sample_dir="tests/test_data/presaved_samples",
1113
batch_size=2,
1214
num_workers=0,
1315
prefetch_factor=None,
1416
train_period=[None, None],
1517
val_period=[None, None],
16-
test_period=[None, None],
17-
batch_dir="tests/test_data/sample_batches",
1818
)
1919

2020

21+
@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
2122
def test_wind_init():
2223
dm = WindDataModule(
2324
configuration=None,
@@ -30,7 +31,7 @@ def test_wind_init():
3031
batch_dir="tests/data/sample_batches",
3132
)
3233

33-
34+
@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
3435
def test_wind_init_with_nwp_filter():
3536
dm = WindDataModule(
3637
configuration=None,
@@ -53,6 +54,7 @@ def test_wind_init_with_nwp_filter():
5354
assert batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp].shape[2] == 2
5455

5556

57+
@pytest.mark.skip(reason="Has not been updated for ocf-data-sampler yet")
5658
def test_pv_site_init():
5759
dm = PVSiteDataModule(
5860
configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml",
@@ -69,13 +71,12 @@ def test_pv_site_init():
6971
def test_iter():
7072
dm = DataModule(
7173
configuration=None,
74+
sample_dir="tests/test_data/presaved_samples",
7275
batch_size=2,
7376
num_workers=0,
7477
prefetch_factor=None,
7578
train_period=[None, None],
7679
val_period=[None, None],
77-
test_period=[None, None],
78-
batch_dir="tests/test_data/sample_batches",
7980
)
8081

8182
batch = next(iter(dm.train_dataloader()))
@@ -84,15 +85,21 @@ def test_iter():
8485
def test_iter_multiprocessing():
8586
dm = DataModule(
8687
configuration=None,
87-
batch_size=2,
88+
sample_dir="tests/test_data/presaved_samples",
89+
batch_size=1,
8890
num_workers=2,
89-
prefetch_factor=2,
91+
prefetch_factor=1,
9092
train_period=[None, None],
9193
val_period=[None, None],
92-
test_period=[None, None],
93-
batch_dir="tests/test_data/sample_batches",
9494
)
9595

96-
batch = next(iter(dm.train_dataloader()))
96+
served_batches = 0
9797
for batch in dm.train_dataloader():
98-
pass
98+
served_batches += 1
99+
100+
# Stop once we've got 2 batches
101+
if served_batches==2:
102+
break
103+
104+
# Make sure we've served 2 batches
105+
assert served_batches == 2

tests/models/baseline/test_last_value.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def last_value_model(model_minutes_kwargs):
99

1010

1111
def test_model_forward(last_value_model, sample_batch):
12+
1213
y = last_value_model(sample_batch)
1314

1415
# check output is the correct shape

tests/models/multimodal/site_encoders/test_encoders.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
import pytest
1111

1212

13-
def _test_model_forward(batch, model_class, kwargs):
13+
def _test_model_forward(batch, model_class, kwargs, batch_size):
1414
model = model_class(**kwargs)
1515
y = model(batch)
16-
assert tuple(y.shape) == (2, kwargs["out_features"]), y.shape
16+
assert tuple(y.shape) == (batch_size, kwargs["out_features"]), y.shape
1717

1818

1919
def _test_model_backward(batch, model_class, kwargs):
@@ -24,22 +24,37 @@ def _test_model_backward(batch, model_class, kwargs):
2424

2525

2626
# Test model forward on all models
27-
def test_simplelearnedaggregator_forward(sample_batch, site_encoder_model_kwargs):
28-
_test_model_forward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs)
27+
def test_simplelearnedaggregator_forward(sample_pv_batch, site_encoder_model_kwargs):
28+
_test_model_forward(
29+
sample_pv_batch,
30+
SimpleLearnedAggregator,
31+
site_encoder_model_kwargs,
32+
batch_size=8,
33+
)
2934

3035

31-
def test_singleattentionnetwork_forward(sample_batch, site_encoder_model_kwargs):
32-
_test_model_forward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs)
36+
def test_singleattentionnetwork_forward(sample_pv_batch, site_encoder_model_kwargs):
37+
_test_model_forward(
38+
sample_pv_batch,
39+
SingleAttentionNetwork,
40+
site_encoder_model_kwargs,
41+
batch_size=8,
42+
)
3343

3444

3545
def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
36-
_test_model_forward(sample_wind_batch, SingleAttentionNetwork, site_encoder_sensor_model_kwargs)
46+
_test_model_forward(
47+
sample_wind_batch,
48+
SingleAttentionNetwork,
49+
site_encoder_sensor_model_kwargs,
50+
batch_size=2,
51+
)
3752

3853

3954
# Test model backward on all models
40-
def test_simplelearnedaggregator_backward(sample_batch, site_encoder_model_kwargs):
41-
_test_model_backward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs)
55+
def test_simplelearnedaggregator_backward(sample_pv_batch, site_encoder_model_kwargs):
56+
_test_model_backward(sample_pv_batch, SimpleLearnedAggregator, site_encoder_model_kwargs)
4257

4358

44-
def test_singleattentionnetwork_backward(sample_batch, site_encoder_model_kwargs):
45-
_test_model_backward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs)
59+
def test_singleattentionnetwork_backward(sample_pv_batch, site_encoder_model_kwargs):
60+
_test_model_backward(sample_pv_batch, SingleAttentionNetwork, site_encoder_model_kwargs)

0 commit comments

Comments
 (0)