# test_dynamic_encoder.py

"""Testing for dynamic fusion encoder components"""
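
# These tests exercise DynamicFusionEncoder in isolation; they can be run with
# e.g. `pytest -v test_dynamic_encoder.py` (path assumed, adjust to the repo's
# test layout).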

import pytest
import torch
from typing import Dict

from pvnet.models.multimodal.encoders.dynamic_encoder import DynamicFusionEncoder


# Fixtures
@pytest.fixture
def minimal_config():
    """Generate minimal config for basic functionality testing"""
    sequence_length = 12
    hidden_dim = 48  # Chosen so it divides evenly by sequence_length (48 / 12 = 4)
    feature_dim = hidden_dim // sequence_length  # Must match between modalities

    return {
        'sequence_length': sequence_length,
        # ... (image_size_pixels, modality_channels, out_features, hidden_dim
        #      and fc_features are set here; values not shown in the source)
        'modality_encoders': {
            'sat': {
                'image_size_pixels': 24,
                'out_features': hidden_dim,
                'number_of_conv3d_layers': 2,
                'conv3d_channels': 16,
                'batch_norm': True,
                'fc_dropout': 0.1
            },
            'pv': {
                'num_sites': 10,
                'out_features': feature_dim
            }
        },
        'cross_attention': {
            'embed_dim': feature_dim,
            'num_heads': 4,
            'dropout': 0.1,
            'num_modalities': 2
        },
        'modality_gating': {
            'feature_dims': {
                'sat': feature_dim,
                'pv': feature_dim
            },
            'hidden_dim': feature_dim,
            'dropout': 0.1
        },
        'dynamic_fusion': {
            'feature_dims': {
                'sat': feature_dim,
                'pv': feature_dim
            },
            'hidden_dim': feature_dim,
            'num_heads': 4,
            'dropout': 0.1,
            'fusion_method': 'weighted_sum',
            'use_residual': True
        }
    }
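
# Note on dimensions: the fusion-stage blocks (cross_attention, modality_gating,
# dynamic_fusion) all operate at feature_dim per timestep, while the 'sat'
# encoder emits hidden_dim (feature_dim * sequence_length); the encoder is
# presumably responsible for reshaping that back to per-timestep features.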


@pytest.fixture
def minimal_inputs(minimal_config):
    """Generate minimal inputs with the expected tensor shapes"""
    batch_size = 2
    sequence_length = minimal_config['sequence_length']

    return {
        'sat': torch.randn(batch_size, 2, sequence_length, 24, 24),
        'pv': torch.randn(batch_size, sequence_length, 10)
    }
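
# Shape conventions assumed here: 'sat' is [batch, channels, time, height, width]
# and 'pv' is [batch, time, n_sites].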


def create_encoder(config):
    """Create an encoder from a single consistent config dict"""
    return DynamicFusionEncoder(
        sequence_length=config['sequence_length'],
        image_size_pixels=config['image_size_pixels'],
        modality_channels=config['modality_channels'],
        out_features=config['out_features'],
        modality_encoders=config['modality_encoders'],
        cross_attention=config['cross_attention'],
        modality_gating=config['modality_gating'],
        dynamic_fusion=config['dynamic_fusion'],
        hidden_dim=config['hidden_dim'],
        fc_features=config['fc_features']
    )
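
# Each test builds a fresh encoder via create_encoder, so hook registrations
# and train/eval state never leak between tests.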


def test_initialisation(minimal_config):
    """Verify encoder initialisation and component structure"""
    encoder = create_encoder(minimal_config)
    assert isinstance(encoder, DynamicFusionEncoder)
    assert len(encoder.modality_encoders) == 2
    assert 'sat' in encoder.modality_encoders
    assert 'pv' in encoder.modality_encoders


def test_basic_forward(minimal_config, minimal_inputs):
    """Test basic forward pass shape and values"""
    encoder = create_encoder(minimal_config)

    with torch.no_grad():
        output = encoder(minimal_inputs)

    assert output.shape == (2, minimal_config['out_features'])
    assert not torch.isnan(output).any()
    assert output.dtype == torch.float32


# Modality handling tests
def test_single_modality(minimal_config, minimal_inputs):
    """Test forward pass with a single modality"""
    encoder = create_encoder(minimal_config)

    # Test with only satellite data
    with torch.no_grad():
        sat_only = {'sat': minimal_inputs['sat']}
        output_sat = encoder(sat_only)

    assert output_sat.shape == (2, minimal_config['out_features'])
    assert not torch.isnan(output_sat).any()

    # Test with only PV data
    with torch.no_grad():
        pv_only = {'pv': minimal_inputs['pv']}
        output_pv = encoder(pv_only)

    assert output_pv.shape == (2, minimal_config['out_features'])
    assert not torch.isnan(output_pv).any()


def test_intermediate_shapes(minimal_config, minimal_inputs):
    """Verify shapes of intermediate tensors throughout the network"""
    encoder = create_encoder(minimal_config)
    batch_size = minimal_inputs['sat'].size(0)
    sequence_length = minimal_config['sequence_length']
    feature_dim = minimal_config['hidden_dim'] // sequence_length

    def hook_fn(module, input, output):
        if isinstance(output, dict):
            for key, value in output.items():
                assert len(value.shape) == 3  # [batch, sequence, features]
                assert value.size(0) == batch_size
                assert value.size(1) == sequence_length
                assert value.size(2) == feature_dim
        elif isinstance(output, torch.Tensor):
            if len(output.shape) == 3:
                assert output.size(0) == batch_size
                assert output.size(1) == sequence_length

    # Register hooks
    if hasattr(encoder, 'gating'):
        encoder.gating.register_forward_hook(hook_fn)
    if hasattr(encoder, 'cross_attention'):
        encoder.cross_attention.register_forward_hook(hook_fn)

    with torch.no_grad():
        encoder(minimal_inputs)
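    # The hooks fire inside the forward call above, so a shape mismatch
    # surfaces as an assertion failure during encoder(minimal_inputs).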


# Robustness tests
@pytest.mark.parametrize("batch_size", [1, 4])
def test_batch_sizes(minimal_config, minimal_inputs, batch_size):
    """Test encoder behavior with different batch sizes"""
    encoder = create_encoder(minimal_config)

    # Adjust input batch sizes
    adjusted_inputs = {}
    for k, v in minimal_inputs.items():
        if batch_size < v.size(0):
            adjusted_inputs[k] = v[:batch_size]
        else:
            # Reconstructed repeat logic (the original body is elided in the
            # source): tile along the batch dimension only
            repeats = [batch_size // v.size(0)] + [1] * (v.dim() - 1)
            adjusted_inputs[k] = v.repeat(*repeats)

    output = encoder(adjusted_inputs)

    assert output.shape == (batch_size, minimal_config['out_features'])
    assert not torch.isnan(output).any()


# Error handling tests
def test_empty_input(minimal_config):
    """Verify error handling for an empty input dictionary"""
    encoder = create_encoder(minimal_config)
    with pytest.raises(ValueError, match="No valid features after encoding"):
        encoder({})


def test_invalid_modality(minimal_config, minimal_inputs):
    """Verify error handling for an invalid modality name"""
    encoder = create_encoder(minimal_config)
    invalid_inputs = {'invalid_modality': minimal_inputs['sat']}
    with pytest.raises(ValueError):
        encoder(invalid_inputs)


def test_none_inputs(minimal_config, minimal_inputs):
    """Test handling of None inputs for modalities"""
    encoder = create_encoder(minimal_config)
    none_inputs = {'sat': None, 'pv': minimal_inputs['pv']}
    output = encoder(none_inputs)
    assert output.shape == (2, minimal_config['out_features'])
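    # A None entry is expected to be skipped rather than raise, so this
    # exercises the encoder's single-modality (PV-only) fallback.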


# Config tests
@pytest.mark.parametrize("sequence_length", [6, 24])
def test_variable_sequence_length(minimal_config, sequence_length):
    """Test different sequence lengths"""
    config = minimal_config.copy()  # Shallow copy: only top-level keys are replaced
    config['sequence_length'] = sequence_length
    config['hidden_dim'] = sequence_length * 5  # Keep hidden_dim divisible by sequence_length

    encoder = create_encoder(config)
    batch_size = 2
    inputs = {
        'sat': torch.randn(batch_size, 2, sequence_length, 24, 24),
        'pv': torch.randn(batch_size, sequence_length, 10)
    }

    output = encoder(inputs)
    assert output.shape == (batch_size, config['out_features'])


# Architecture tests
def test_architecture_components(minimal_config):
    """Test specific architectural components and their connections"""
    encoder = create_encoder(minimal_config)

    # Test encoder layers
    assert hasattr(encoder, 'modality_encoders')
    assert hasattr(encoder, 'feature_projections')
    assert hasattr(encoder, 'fusion_module')
    assert hasattr(encoder, 'final_block')

    # Verify encoder has correct number of modalities
    assert len(encoder.modality_encoders) == len(minimal_config['modality_channels'])


def test_tensor_shape_tracking(minimal_config, minimal_inputs):
    """Track tensor shapes through network layers"""
    encoder = create_encoder(minimal_config)
    shapes = {}

    def hook_fn(name):
        # Closure captures `name` so each registered hook records its own entry
        def hook(module, input, output):
            shapes[name] = output.shape if isinstance(output, torch.Tensor) else \
                {k: v.shape for k, v in output.items()}
        return hook

    # Register shape tracking hooks
    encoder.modality_encoders['sat'].register_forward_hook(hook_fn('sat_encoder'))
    encoder.feature_projections['sat'].register_forward_hook(hook_fn('sat_projection'))
    encoder.fusion_module.register_forward_hook(hook_fn('fusion'))

    with torch.no_grad():
        output = encoder(minimal_inputs)

    # Verify expected shapes
    assert shapes['fusion'][1] == encoder.feature_dim
    assert output.shape[1] == minimal_config['out_features']


def test_modality_interactions(minimal_config, minimal_inputs):
    """Test interaction between different modality combinations"""
    encoder = create_encoder(minimal_config)
    encoder.eval()  # Disable dropout so output comparisons are deterministic

    # Test different modality combinations
    test_cases = [
        ({'sat': minimal_inputs['sat']}, "single_sat"),
        ({'pv': minimal_inputs['pv']}, "single_pv"),
        (minimal_inputs, "both")
    ]

    outputs = {}
    for inputs, case_name in test_cases:
        with torch.no_grad():
            outputs[case_name] = encoder(inputs)

    # Verify outputs differ across modality combinations
    assert not torch.allclose(outputs['single_sat'], outputs['both'])
    assert not torch.allclose(outputs['single_pv'], outputs['both'])


def test_attention_behavior(minimal_config, minimal_inputs):
    """Verify attention mechanism properties"""
    encoder = create_encoder(minimal_config)
    attention_outputs = {}

    def attention_hook(module, input, output):
        if isinstance(output, dict):
            attention_outputs.update(output)

    if encoder.use_cross_attention:
        encoder.cross_attention.register_forward_hook(attention_hook)

    with torch.no_grad():
        encoder(minimal_inputs)

    if attention_outputs:
        # Verify the attended features vary (not collapsed to a constant)
        for modality, features in attention_outputs.items():
            std = features.std()
            assert std > 1e-6, "Attention output too uniform"


@pytest.mark.parametrize("noise_level", [0.1, 0.5, 1.0])
def test_input_noise_robustness(minimal_config, minimal_inputs, noise_level):
    """Test encoder stability under different noise levels"""
    encoder = create_encoder(minimal_config)
    encoder.eval()  # Disable dropout so the clean/noisy comparison is fair

    # Add controlled noise to inputs
    noisy_inputs = {
        k: v + noise_level * torch.randn_like(v)
        for k, v in minimal_inputs.items()
    }

    with torch.no_grad():
        clean_output = encoder(minimal_inputs)
        noisy_output = encoder(noisy_inputs)

    # Check output stability: mean absolute change relative to mean magnitude,
    # with a loose bound that scales with the injected noise level
    relative_diff = (clean_output - noisy_output).abs().mean() / clean_output.abs().mean()
    assert not torch.isnan(relative_diff)
    assert not torch.isinf(relative_diff)
    assert relative_diff < noise_level * 10