Commit 7e74993

Model testing update
1 parent af54862 commit 7e74993

File tree

1 file changed: test_multimodal_dynamic.py (+350 −0 lines)

@@ -0,0 +1,350 @@
# test_multimodal_dynamic.py


""" Testing for dynamic fusion multimodal model definition """


import pytest
import torch
import torch.nn as nn

from omegaconf import DictConfig
from ocf_datapipes.batch import BatchKey, NWPBatchKey
from torch.optim import SGD

from pvnet.models.multimodal.multimodal_dynamic import Model
from pvnet.models.multimodal.linear_networks.output_networks import DynamicOutputNetwork

class MockNWPEncoder(nn.Module):
    """ Simplified mock encoder - explicit dimension handling """

    def __init__(self, in_channels=4, image_size_pixels=224):
        super().__init__()
        self.keywords = {"in_channels": in_channels}
        self.image_size_pixels = image_size_pixels
        self.hidden_dim = 256

        # Generate exact feature size needed
        self.features = nn.Parameter(torch.randn(self.hidden_dim))

    def forward(self, x):
        batch_size = x.size(0)
        return self.features.unsqueeze(0).expand(batch_size, -1)

# Basic model as fixture - definition
@pytest.fixture
def basic_model():
    nwp_encoders_dict = {"mock_nwp": MockNWPEncoder()}
    nwp_forecast_minutes = DictConfig({"mock_nwp": 60})
    nwp_history_minutes = DictConfig({"mock_nwp": 60})

    model = Model(
        output_network=DynamicOutputNetwork,
        nwp_encoders_dict=nwp_encoders_dict,
        pv_encoder=None,
        wind_encoder=None,
        sensor_encoder=None,
        add_image_embedding_channel=False,
        include_gsp_yield_history=False,
        include_sun=False,
        include_time=False,
        embedding_dim=None,
        fusion_hidden_dim=256,
        num_fusion_heads=8,
        fusion_dropout=0.1,
        fusion_method="weighted_sum",
        forecast_minutes=30,
        history_minutes=60,
        nwp_forecast_minutes=nwp_forecast_minutes,
        nwp_history_minutes=nwp_history_minutes,
    )

    return model

def test_model_forward_pass(basic_model):
    """ Standard forward pass test """

    batch_size = 4
    sequence_length = basic_model.history_len
    height = width = 224
    channels = 4

    mock_nwp_data = torch.randn(batch_size, sequence_length, channels, height, width)
    batch = {
        BatchKey.nwp: {
            "mock_nwp": {
                NWPBatchKey.nwp: mock_nwp_data
            }
        }
    }

    with torch.no_grad():
        encoded_nwp = basic_model.nwp_encoders_dict["mock_nwp"](mock_nwp_data)
        print(f"Encoded NWP shape: {encoded_nwp.shape}")

        output, encoded_features = basic_model(batch)

    # Assert - check dimensions with forward pass
    assert output.shape == (batch_size, basic_model.num_output_features)
    assert isinstance(encoded_features, torch.Tensor)
    assert encoded_features.shape == (batch_size, basic_model.fusion_hidden_dim)

def test_model_init_minimal():
    """ Minimal initialisation of model test """

    nwp_encoders_dict = {"mock_nwp": MockNWPEncoder()}
    nwp_forecast_minutes = DictConfig({"mock_nwp": 60})
    nwp_history_minutes = DictConfig({"mock_nwp": 60})

    model = Model(
        output_network=DynamicOutputNetwork,
        nwp_encoders_dict=nwp_encoders_dict,
        pv_encoder=None,
        wind_encoder=None,
        sensor_encoder=None,
        add_image_embedding_channel=False,
        include_gsp_yield_history=False,
        include_sun=False,
        include_time=False,
        embedding_dim=None,
        fusion_hidden_dim=256,
        num_fusion_heads=8,
        fusion_dropout=0.1,
        fusion_method="weighted_sum",
        forecast_minutes=30,
        history_minutes=60,
        nwp_forecast_minutes=nwp_forecast_minutes,
        nwp_history_minutes=nwp_history_minutes,
    )

    assert isinstance(model, nn.Module)
    assert model.include_nwp
    assert not model.include_pv
    assert not model.include_wind
    assert not model.include_sensor
    assert not model.include_sun
    assert not model.include_time
    assert not model.include_gsp_yield_history

    assert isinstance(model.nwp_encoders_dict, dict)
    assert "mock_nwp" in model.nwp_encoders_dict

    assert isinstance(model.encoder, nn.Module)
    assert isinstance(model.output_network, nn.Module)

def test_model_quantile_regression():
    """ Test model with quantile regression config """

    # Create model with quantile regression
    quantile_model = Model(
        output_network=DynamicOutputNetwork,
        output_quantiles=[0.1, 0.5, 0.9],
        nwp_encoders_dict={"mock_nwp": MockNWPEncoder()},
        nwp_forecast_minutes=DictConfig({"mock_nwp": 60}),
        nwp_history_minutes=DictConfig({"mock_nwp": 60}),
        pv_encoder=None,
        wind_encoder=None,
        sensor_encoder=None,
        add_image_embedding_channel=False,
        include_gsp_yield_history=False,
        include_sun=False,
        include_time=False,
        embedding_dim=None,
        fusion_hidden_dim=256,
        num_fusion_heads=8,
        fusion_dropout=0.1,
        fusion_method="weighted_sum",
        forecast_minutes=30,
        history_minutes=60
    )

    batch_size = 4
    sequence_length = quantile_model.history_len
    height = width = 224
    channels = 4

    mock_nwp_data = torch.randn(batch_size, sequence_length, channels, height, width)
    batch = {
        BatchKey.nwp: {
            "mock_nwp": {
                NWPBatchKey.nwp: mock_nwp_data
            }
        }
    }

    with torch.no_grad():
        output, encoded_features = quantile_model(batch)

    # Verify output shape and type are correct when using multiple quantiles
    assert quantile_model.use_quantile_regression
    assert len(quantile_model.output_quantiles) == 3
    assert output.shape == (batch_size, quantile_model.forecast_len, len(quantile_model.output_quantiles))
    assert torch.isfinite(output).all()

    # Random init variation check
    quantile_variances = output.std(dim=2)
    assert (quantile_variances > 0).any(), "Quantile predictions should show some variation"

def test_model_partial_inputs_and_error_handling(basic_model):
    """ Check error handling / robustness of model """

    batch_size = 4
    sequence_length = basic_model.history_len
    height = width = 224
    channels = 4

    # Minimal valid input
    minimal_batch = {
        BatchKey.nwp: {
            "mock_nwp": {
                NWPBatchKey.nwp: torch.randn(batch_size, sequence_length, channels, height, width)
            }
        }
    }

    with torch.no_grad():
        output, encoded_features = basic_model(minimal_batch)

    assert output.shape == (batch_size, basic_model.num_output_features)
    assert encoded_features.shape == (batch_size, basic_model.fusion_hidden_dim)
    assert torch.isfinite(output).all()

    # Missing NWP data
    empty_nwp_batch = {
        BatchKey.nwp: {}
    }

    with pytest.raises(Exception):
        with torch.no_grad():
            _ = basic_model(empty_nwp_batch)

    # None input for NWP
    none_nwp_batch = {
        BatchKey.nwp: {
            "mock_nwp": {
                NWPBatchKey.nwp: None
            }
        }
    }

    with pytest.raises(Exception):
        with torch.no_grad():
            _ = basic_model(none_nwp_batch)

    # Empty input dict
    empty_batch = {}

    with pytest.raises(Exception):
        with torch.no_grad():
            _ = basic_model(empty_batch)

    # Verify model can handle variations in input
    varied_sequence_batch = {
        BatchKey.nwp: {
            "mock_nwp": {
                NWPBatchKey.nwp: torch.randn(batch_size, max(1, sequence_length - 1), channels, height, width)
            }
        }
    }

    try:
        with torch.no_grad():
            result, _ = basic_model(varied_sequence_batch)
    except Exception as e:
        # If shorter sequences are rejected, the error should at least mention the input or its shape
        assert "input" in str(e).lower() or "shape" in str(e).lower()

def test_model_backward(basic_model):
    """ Test backward pass functionality - backprop verify """

    batch_size = 4
    sequence_length = basic_model.history_len
    height = width = 224
    channels = 4

    # Prepare input batch
    batch = {
        BatchKey.nwp: {
            "mock_nwp": {
                NWPBatchKey.nwp: torch.randn(batch_size, sequence_length, channels, height, width)
            }
        }
    }

    optimizer = SGD(basic_model.parameters(), lr=0.001)
    output, _ = basic_model(batch)

    # Backward pass
    optimizer.zero_grad()
    output.sum().backward()

    # Check gradients are not None
    for name, param in basic_model.named_parameters():
        if param.requires_grad:
            assert param.grad is not None, f"Gradient for {name} is None"

def test_quantile_model_backward():
    """ Test backward pass functionality - backprop verify - quantile regression """

    # Create model with quantile regression
    quantile_model = Model(
        output_network=DynamicOutputNetwork,
        output_quantiles=[0.1, 0.5, 0.9],
        nwp_encoders_dict={"mock_nwp": MockNWPEncoder()},
        nwp_forecast_minutes=DictConfig({"mock_nwp": 60}),
        nwp_history_minutes=DictConfig({"mock_nwp": 60}),
        pv_encoder=None,
        wind_encoder=None,
        sensor_encoder=None,
        add_image_embedding_channel=False,
        include_gsp_yield_history=False,
        include_sun=False,
        include_time=False,
        embedding_dim=None,
        fusion_hidden_dim=256,
        num_fusion_heads=8,
        fusion_dropout=0.1,
        fusion_method="weighted_sum",
        forecast_minutes=30,
        history_minutes=60
    )

    batch_size = 4
    sequence_length = quantile_model.history_len
    height = width = 224
    channels = 4

    # Prepare input batch
    batch = {
        BatchKey.nwp: {
            "mock_nwp": {
                NWPBatchKey.nwp: torch.randn(batch_size, sequence_length, channels, height, width)
            }
        }
    }

    optimizer = SGD(quantile_model.parameters(), lr=0.001)
    output, _ = quantile_model(batch)

    # Backward pass
    optimizer.zero_grad()
    output.sum().backward()

    # Check quantile regression specific properties
    assert quantile_model.use_quantile_regression
    assert len(quantile_model.output_quantiles) == 3
    assert output.shape == (batch_size, quantile_model.forecast_len, len(quantile_model.output_quantiles))

    # Check gradients are not None
    for name, param in quantile_model.named_parameters():
        if param.requires_grad:
            assert param.grad is not None, f"Gradient for {name} is None"

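For reference, the new tests can be run on their own through pytest's Python API. The sketch below is not part of the commit, and the module path is an assumption about the repository layout rather than something stated here.

# Minimal sketch (hypothetical helper, not in the commit): run only this test module.
import pytest

# Assumed file location - adjust to the actual repository layout.
exit_code = pytest.main(["-v", "tests/models/multimodal/test_multimodal_dynamic.py"])
raise SystemExit(exit_code)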