Skip to content

Commit 3b8d59d

Browse files
committed
Tests for dynamic encoder
1 parent 38768f4 commit 3b8d59d

File tree

1 file changed

+278
-30
lines changed

1 file changed

+278
-30
lines changed
Lines changed: 278 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1+
# test_dynamic_encoder.py
2+
3+
4+
""" Testing for dynamic fusion encoder components """
5+
6+
17
import pytest
28
import torch
39
from typing import Dict
410

511
from pvnet.models.multimodal.encoders.dynamic_encoder import DynamicFusionEncoder
612

13+
14+
# Fixtures
715
@pytest.fixture
816
def minimal_config():
9-
"""Minimal configuration for testing basic functionality"""
17+
""" Generate minimal config - basic functionality testing """
1018
sequence_length = 12
11-
hidden_dim = 60 # Chosen so it divides evenly by sequence_length (60/12 = 5)
12-
13-
# Important: feature_dim needs to match between modalities
14-
feature_dim = hidden_dim // sequence_length # This is 5
19+
hidden_dim = 48
20+
feature_dim = hidden_dim // sequence_length
1521

1622
return {
1723
'sequence_length': sequence_length,
@@ -26,47 +32,49 @@ def minimal_config():
2632
'modality_encoders': {
2733
'sat': {
2834
'image_size_pixels': 24,
29-
'out_features': feature_dim * sequence_length, # 60
35+
'out_features': hidden_dim,
3036
'number_of_conv3d_layers': 2,
3137
'conv3d_channels': 16,
3238
'batch_norm': True,
3339
'fc_dropout': 0.1
3440
},
3541
'pv': {
3642
'num_sites': 10,
37-
'out_features': feature_dim # 5 - this ensures proper dimension
43+
'out_features': feature_dim
3844
}
3945
},
4046
'cross_attention': {
41-
'embed_dim': hidden_dim,
47+
'embed_dim': feature_dim,
4248
'num_heads': 4,
4349
'dropout': 0.1,
4450
'num_modalities': 2
4551
},
4652
'modality_gating': {
4753
'feature_dims': {
48-
'sat': hidden_dim,
49-
'pv': hidden_dim
54+
'sat': feature_dim,
55+
'pv': feature_dim
5056
},
51-
'hidden_dim': hidden_dim,
57+
'hidden_dim': feature_dim, # Changed to feature_dim
5258
'dropout': 0.1
5359
},
5460
'dynamic_fusion': {
5561
'feature_dims': {
56-
'sat': hidden_dim,
57-
'pv': hidden_dim
62+
'sat': feature_dim,
63+
'pv': feature_dim
5864
},
59-
'hidden_dim': hidden_dim,
65+
'hidden_dim': feature_dim, # Changed to feature_dim
6066
'num_heads': 4,
6167
'dropout': 0.1,
6268
'fusion_method': 'weighted_sum',
6369
'use_residual': True
6470
}
6571
}
6672

73+
6774
@pytest.fixture
6875
def minimal_inputs(minimal_config):
69-
"""Generate minimal test inputs"""
76+
""" Generate minimal inputs with expected tensor shapes """
77+
7078
batch_size = 2
7179
sequence_length = minimal_config['sequence_length']
7280

@@ -75,22 +83,106 @@ def minimal_inputs(minimal_config):
7583
'pv': torch.randn(batch_size, sequence_length, 10)
7684
}
7785

78-
def test_batch_sizes(self, minimal_config, minimal_inputs, batch_size):
79-
"""Test different batch sizes"""
80-
encoder = DynamicFusionEncoder(
81-
sequence_length=minimal_config['sequence_length'],
82-
image_size_pixels=minimal_config['image_size_pixels'],
83-
modality_channels=minimal_config['modality_channels'],
84-
out_features=minimal_config['out_features'],
85-
modality_encoders=minimal_config['modality_encoders'],
86-
cross_attention=minimal_config['cross_attention'],
87-
modality_gating=minimal_config['modality_gating'],
88-
dynamic_fusion=minimal_config['dynamic_fusion'],
89-
hidden_dim=minimal_config['hidden_dim'],
90-
fc_features=minimal_config['fc_features']
86+
87+
def create_encoder(config):
88+
""" Helper function - create encoder with consistent config """
89+
90+
return DynamicFusionEncoder(
91+
sequence_length=config['sequence_length'],
92+
image_size_pixels=config['image_size_pixels'],
93+
modality_channels=config['modality_channels'],
94+
out_features=config['out_features'],
95+
modality_encoders=config['modality_encoders'],
96+
cross_attention=config['cross_attention'],
97+
modality_gating=config['modality_gating'],
98+
dynamic_fusion=config['dynamic_fusion'],
99+
hidden_dim=config['hidden_dim'],
100+
fc_features=config['fc_features']
91101
)
102+
103+
104+
def test_initialisation(minimal_config):
105+
""" Verify encoder initialisation / component structure """
106+
107+
encoder = create_encoder(minimal_config)
108+
assert isinstance(encoder, DynamicFusionEncoder)
109+
assert len(encoder.modality_encoders) == 2
110+
assert 'sat' in encoder.modality_encoders
111+
assert 'pv' in encoder.modality_encoders
112+
113+
114+
def test_basic_forward(minimal_config, minimal_inputs):
115+
""" Test basic forward pass shape and values """
116+
117+
encoder = create_encoder(minimal_config)
118+
119+
with torch.no_grad():
120+
output = encoder(minimal_inputs)
121+
122+
assert output.shape == (2, minimal_config['out_features'])
123+
assert not torch.isnan(output).any()
124+
assert output.dtype == torch.float32
125+
126+
127+
# Modality handling tests
128+
def test_single_modality(minimal_config, minimal_inputs):
129+
""" Test forward pass with single modality """
130+
encoder = create_encoder(minimal_config)
131+
132+
# Test with only satellite data
133+
with torch.no_grad():
134+
sat_only = {'sat': minimal_inputs['sat']}
135+
output_sat = encoder(sat_only)
136+
137+
assert output_sat.shape == (2, minimal_config['out_features'])
138+
assert not torch.isnan(output_sat).any()
139+
140+
# Test with only PV data
141+
with torch.no_grad():
142+
pv_only = {'pv': minimal_inputs['pv']}
143+
output_pv = encoder(pv_only)
144+
145+
assert output_pv.shape == (2, minimal_config['out_features'])
146+
assert not torch.isnan(output_pv).any()
147+
148+
149+
def test_intermediate_shapes(minimal_config, minimal_inputs):
150+
""" Verify shapes of intermediate tensors throughout network """
151+
152+
encoder = create_encoder(minimal_config)
153+
batch_size = minimal_inputs['sat'].size(0)
154+
sequence_length = minimal_config['sequence_length']
155+
feature_dim = minimal_config['hidden_dim'] // sequence_length
156+
157+
def hook_fn(module, input, output):
158+
if isinstance(output, dict):
159+
for key, value in output.items():
160+
assert len(value.shape) == 3 # [batch, sequence, features]
161+
assert value.size(0) == batch_size
162+
assert value.size(1) == sequence_length
163+
assert value.size(2) == feature_dim
164+
elif isinstance(output, torch.Tensor):
165+
if len(output.shape) == 3:
166+
assert output.size(0) == batch_size
167+
assert output.size(1) == sequence_length
168+
169+
# Register hooks
170+
if hasattr(encoder, 'gating'):
171+
encoder.gating.register_forward_hook(hook_fn)
172+
if hasattr(encoder, 'cross_attention'):
173+
encoder.cross_attention.register_forward_hook(hook_fn)
174+
175+
with torch.no_grad():
176+
encoder(minimal_inputs)
177+
178+
179+
# Robustness tests
180+
@pytest.mark.parametrize("batch_size", [1, 4])
181+
def test_batch_sizes(minimal_config, minimal_inputs, batch_size):
182+
""" Test encoder behavior with different batch sizes """
183+
encoder = create_encoder(minimal_config)
92184

93-
# Adjust input batch sizes - fixed repeat logic
185+
# Adjust input batch sizes
94186
adjusted_inputs = {}
95187
for k, v in minimal_inputs.items():
96188
if batch_size < v.size(0):
@@ -103,4 +195,160 @@ def test_batch_sizes(self, minimal_config, minimal_inputs, batch_size):
103195
output = encoder(adjusted_inputs)
104196

105197
assert output.shape == (batch_size, minimal_config['out_features'])
106-
assert not torch.isnan(output).any()
198+
assert not torch.isnan(output).any()
199+
200+
201+
# Error handling tests
202+
def test_empty_input(minimal_config):
203+
""" Verify error handling for empty input dictionary """
204+
205+
encoder = create_encoder(minimal_config)
206+
with pytest.raises(ValueError, match="No valid features after encoding"):
207+
encoder({})
208+
209+
210+
def test_invalid_modality(minimal_config, minimal_inputs):
211+
""" Verify error handling for invalid modality name """
212+
213+
encoder = create_encoder(minimal_config)
214+
invalid_inputs = {'invalid_modality': minimal_inputs['sat']}
215+
with pytest.raises(ValueError):
216+
encoder(invalid_inputs)
217+
218+
219+
def test_none_inputs(minimal_config, minimal_inputs):
220+
""" Test handling of None inputs for modalities """
221+
222+
encoder = create_encoder(minimal_config)
223+
none_inputs = {'sat': None, 'pv': minimal_inputs['pv']}
224+
output = encoder(none_inputs)
225+
assert output.shape == (2, minimal_config['out_features'])
226+
227+
228+
# Config tests
229+
@pytest.mark.parametrize("sequence_length", [6, 24])
230+
def test_variable_sequence_length(minimal_config, sequence_length):
231+
"""Test different sequence lengths"""
232+
config = minimal_config.copy()
233+
config['sequence_length'] = sequence_length
234+
config['hidden_dim'] = sequence_length * 5
235+
236+
encoder = create_encoder(config)
237+
batch_size = 2
238+
inputs = {
239+
'sat': torch.randn(batch_size, 2, sequence_length, 24, 24),
240+
'pv': torch.randn(batch_size, sequence_length, 10)
241+
}
242+
243+
output = encoder(inputs)
244+
assert output.shape == (batch_size, config['out_features'])
245+
246+
247+
# Architecture tests
248+
def test_architecture_components(minimal_config):
249+
"""Test specific architectural components and their connections"""
250+
251+
encoder = create_encoder(minimal_config)
252+
253+
# Test encoder layers
254+
assert hasattr(encoder, 'modality_encoders')
255+
assert hasattr(encoder, 'feature_projections')
256+
assert hasattr(encoder, 'fusion_module')
257+
assert hasattr(encoder, 'final_block')
258+
259+
# Verify encoder has correct number of modalities
260+
assert len(encoder.modality_encoders) == len(minimal_config['modality_channels'])
261+
262+
263+
def test_tensor_shape_tracking(minimal_config, minimal_inputs):
264+
""" Track tensor shapes through network layers """
265+
266+
encoder = create_encoder(minimal_config)
267+
shapes = {}
268+
269+
def hook_fn(name):
270+
def hook(module, input, output):
271+
shapes[name] = output.shape if isinstance(output, torch.Tensor) else \
272+
{k: v.shape for k, v in output.items()}
273+
return hook
274+
275+
# Register shape tracking hooks
276+
encoder.modality_encoders['sat'].register_forward_hook(hook_fn('sat_encoder'))
277+
encoder.feature_projections['sat'].register_forward_hook(hook_fn('sat_projection'))
278+
encoder.fusion_module.register_forward_hook(hook_fn('fusion'))
279+
280+
with torch.no_grad():
281+
output = encoder(minimal_inputs)
282+
283+
# Verify expected shapes
284+
assert shapes['fusion'][1] == encoder.feature_dim
285+
assert output.shape[1] == minimal_config['out_features']
286+
287+
288+
def test_modality_interactions(minimal_config, minimal_inputs):
289+
""" Test interaction between different modality combinations """
290+
291+
encoder = create_encoder(minimal_config)
292+
batch_size = 2
293+
294+
# Test different modality combinations
295+
test_cases = [
296+
({'sat': minimal_inputs['sat']}, "single_sat"),
297+
({'pv': minimal_inputs['pv']}, "single_pv"),
298+
(minimal_inputs, "both")
299+
]
300+
301+
outputs = {}
302+
for inputs, case_name in test_cases:
303+
with torch.no_grad():
304+
outputs[case_name] = encoder(inputs)
305+
306+
# Verify outputs differ across modality combinations
307+
assert not torch.allclose(outputs['single_sat'], outputs['both'])
308+
assert not torch.allclose(outputs['single_pv'], outputs['both'])
309+
310+
311+
def test_attention_behavior(minimal_config, minimal_inputs):
312+
""" Verify attention mechanism properties """
313+
314+
encoder = create_encoder(minimal_config)
315+
attention_outputs = {}
316+
317+
def attention_hook(module, input, output):
318+
if isinstance(output, dict):
319+
attention_outputs.update(output)
320+
321+
if encoder.use_cross_attention:
322+
encoder.cross_attention.register_forward_hook(attention_hook)
323+
324+
with torch.no_grad():
325+
encoder(minimal_inputs)
326+
327+
if attention_outputs:
328+
# Verify attention weight distribution
329+
for modality, features in attention_outputs.items():
330+
std = features.std()
331+
assert std > 1e-6, "Attention weights too uniform"
332+
333+
334+
@pytest.mark.parametrize("noise_level", [0.1, 0.5, 1.0])
335+
def test_input_noise_robustness(minimal_config, minimal_inputs, noise_level):
336+
""" Test encoder stability under different noise levels """
337+
338+
encoder = create_encoder(minimal_config)
339+
340+
# Add controlled noise to inputs
341+
noisy_inputs = {
342+
k: v + noise_level * torch.randn_like(v)
343+
for k, v in minimal_inputs.items()
344+
}
345+
346+
with torch.no_grad():
347+
clean_output = encoder(minimal_inputs)
348+
noisy_output = encoder(noisy_inputs)
349+
350+
# Check output stability
351+
relative_diff = (clean_output - noisy_output).abs().mean() / clean_output.abs().mean()
352+
assert not torch.isnan(relative_diff)
353+
assert not torch.isinf(relative_diff)
354+
assert relative_diff < noise_level * 10

0 commit comments

Comments
 (0)