from typing import Dict, Optional, List, Union
import torch
from torch import nn
+import logging

from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating
from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention
from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2


+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger('dynamic_encoder')
+
+
# Attention head compatibility function
def get_compatible_heads(dim: int, target_heads: int) -> int:
    """Calculate largest compatible number of heads <= target_heads"""
+    logger.debug(f"Finding compatible heads for dim={dim}, target_heads={target_heads}")
+
    # Iterative reduction
    # Obtain maximum divisible number of heads
    # h ∈ ℕ : h ≤ target_heads ∧ dim mod h = 0
    for h in range(min(target_heads, dim), 0, -1):
        if dim % h == 0:
+            logger.debug(f"Selected compatible head count: {h}")
            return h
    return 1
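
As a quick illustration of the selection rule (the values below are assumed, not part of this commit):

# Illustrative only: largest h <= target_heads with dim % h == 0, falling back to 1.
assert get_compatible_heads(64, 8) == 8   # 64 % 8 == 0, requested count kept
assert get_compatible_heads(96, 7) == 6   # 7 rejected, 6 divides 96
assert get_compatible_heads(13, 8) == 1   # prime dim: only 1 divides it
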
@@ -35,6 +43,7 @@ class PVEncoder(nn.Module):

    def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        super().__init__()
+        logger.info(f"Initialising PVEncoder with sequence_length={sequence_length}, num_sites={num_sites}")

        # Temporal and spatial configuration parameters
        # L: sequence length
@@ -46,6 +55,7 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        # Basic feature extraction network
        # φ: ℝ^M → ℝ^N
        # Linear Transformation → Layer Normalization → ReLU → Dropout
+        logger.debug("Creating encoder network")
        self.encoder = nn.Sequential(
            nn.Linear(num_sites, out_features),
            nn.LayerNorm(out_features),
@@ -57,13 +67,17 @@ def forward(self, x):

        # Sequential processing - maintain temporal order
        # x ∈ ℝ^{B×L×M} → out ∈ ℝ^{B×L×N}
+        logger.debug(f"PVEncoder input shape: {x.shape}")
        batch_size = x.shape[0]
        out = []
        for t in range(self.sequence_length):
+            logger.debug(f"Processing timestep {t}")
            out.append(self.encoder(x[:, t]))

        # Reshape maintaining sequence dimension
-        return torch.stack(out, dim=1)
+        result = torch.stack(out, dim=1)
+        logger.debug(f"PVEncoder output shape: {result.shape}")
+        return result

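
For orientation, a minimal shape check for PVEncoder; the sizes are illustrative assumptions:

# Illustrative only: the same MLP is applied at every timestep, so
# (batch, sequence_length, num_sites) -> (batch, sequence_length, out_features).
enc = PVEncoder(sequence_length=12, num_sites=10, out_features=32)
x = torch.randn(4, 12, 10)
assert enc(x).shape == (4, 12, 32)
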
# Primary fusion encoder implementation
@@ -87,6 +101,8 @@ def __init__(
    ):
        """Dynamic fusion encoder initialisation"""

+        logger.info(f"Initialising DynamicFusionEncoder with sequence_length={sequence_length}, out_features={out_features}")
+
        super().__init__(
            sequence_length=sequence_length,
            image_size_pixels=image_size_pixels,
@@ -96,10 +112,12 @@ def __init__(

        # Dimension validation and compatibility
        # Adjust hidden dimension to be divisible by sequence length
-        # H = feature_dim × sequence_length
+        # H = feature_dim × sequence_length
+        logger.debug(f"Initial hidden_dim: {hidden_dim}")
        if hidden_dim % sequence_length != 0:
            feature_dim = ((hidden_dim + sequence_length - 1) // sequence_length)
            hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted hidden_dim to {hidden_dim} for sequence length compatibility")
        else:
            feature_dim = hidden_dim // sequence_length

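A worked example of this adjustment, with assumed numbers:

# hidden_dim=100, sequence_length=12: 100 % 12 != 0
#   feature_dim = (100 + 12 - 1) // 12 = 9
#   hidden_dim  = 9 * 12 = 108
# hidden_dim=96, sequence_length=12: feature_dim = 96 // 12 = 8
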
@@ -108,15 +126,18 @@ def __init__(
        # h ∈ ℕ : h ≤ num_heads ∧ feature_dim mod h = 0
        attention_heads = cross_attention.get('num_heads', num_heads)
        attention_heads = get_compatible_heads(feature_dim, attention_heads)
-
+        logger.debug(f"Using {attention_heads} attention heads")
+
        # Dimension adjustment for attention mechanism
        # Ensure feature dimension is compatible with attention heads
        if feature_dim < attention_heads:
            feature_dim = attention_heads
            hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted dimensions - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")
        elif feature_dim % attention_heads != 0:
            feature_dim = ((feature_dim + attention_heads - 1) // attention_heads) * attention_heads
            hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted for attention compatibility - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")

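Continuing the assumed numbers from above:

# feature_dim=9 with a requested num_heads of 8:
#   get_compatible_heads(9, 8) -> 3   (first divisor of 9 at or below 8)
#   9 >= 3 and 9 % 3 == 0, so neither adjustment branch fires;
#   attention runs with 3 heads over the 9-dimensional feature space.
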
        # Architecture dimensions
        self.feature_dim = feature_dim
@@ -129,6 +150,7 @@ def __init__(
        dynamic_fusion['num_heads'] = attention_heads

        # Modality specific encoder instantiation
+        logger.debug("Creating modality encoders")
        self.modality_encoders = nn.ModuleDict()
        for modality, config in modality_encoders.items():
            config = config.copy()
@@ -161,6 +183,7 @@ def __init__(
        )

        # Feature transformation layers
+        logger.debug("Creating feature projections")
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.LayerNorm(feature_dim),
@@ -174,6 +197,7 @@ def __init__(
        # Modality gating mechanism
        self.use_gating = use_gating
        if use_gating:
+            logger.debug("Initialising gating mechanism")
            gating_config = modality_gating.copy()
            gating_config.update({
                'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -184,6 +208,7 @@ def __init__(
        # Cross modal attention mechanism
        self.use_cross_attention = use_cross_attention and len(modality_channels) > 1
        if self.use_cross_attention:
+            logger.debug("Initialising cross attention")
            attention_config = cross_attention.copy()
            attention_config.update({
                'embed_dim': feature_dim,
@@ -194,6 +219,7 @@ def __init__(
            self.cross_attention = CrossModalAttention(**attention_config)

        # Dynamic fusion implementation
+        logger.debug("Initialising fusion module")
        fusion_config = dynamic_fusion.copy()
        fusion_config.update({
            'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -204,6 +230,7 @@ def __init__(
        self.fusion_module = DynamicFusionModule(**fusion_config)

        # Output network definition
+        logger.debug("Creating final output block")
        self.final_block = nn.Sequential(
            nn.Linear(hidden_dim, fc_features),
            nn.LayerNorm(fc_features),
@@ -219,6 +246,9 @@ def forward(
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

+        logger.info("Starting DynamicFusionEncoder forward pass")
+        logger.debug(f"Input modalities: {list(inputs.keys())}")
+
        # Encoded features dictionary
        # M ∈ {x_m | m ∈ Modalities}
        encoded_features = {}
@@ -230,80 +260,91 @@ def forward(
                continue

            # Feature extraction and projection
+            logger.debug(f"Encoding {modality} input of shape {x.shape}")
            encoded = self.modality_encoders[modality](x)
-            print(f"Encoded {modality} shape: {encoded.shape}")
+            logger.debug(f"Encoded {modality} shape: {encoded.shape}")

            # Temporal projection across sequence
            # π: ℝ^{B×L×D} → ℝ^{B×L×D}
            projected = torch.stack([
                self.feature_projections[modality](encoded[:, t])
                for t in range(self.sequence_length)
            ], dim=1)
-            print(f"Projected {modality} shape: {projected.shape}")
+            logger.debug(f"Projected {modality} shape: {projected.shape}")

            encoded_features[modality] = projected

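A shape sketch for the per-timestep projection above, with assumed dimensions:

# With B=4, L=sequence_length=12, D=feature_dim=16 (illustrative):
#   encoded                   : (4, 12, 16)
#   encoded[:, t]             : (4, 16)       projected one timestep at a time
#   torch.stack([...], dim=1) : (4, 12, 16)   sequence dimension restored
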
        # Validation of encoded feature space
        # |M| > 0
        if not encoded_features:
-            raise ValueError("No valid features after encoding")
+            error_msg = "No valid features after encoding"
+            logger.error(error_msg)
+            raise ValueError(error_msg)

        # Apply modality interaction mechanisms
        if self.use_gating:

            # g: M → M̂
            # Adaptive feature transformation with learned gates
+            logger.debug("Applying modality gating")
            encoded_features = self.gating(encoded_features)
-            print(f"After gating, encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+            logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

        # Cross-modal attention mechanism
        if self.use_cross_attention:
            if len(encoded_features) > 1:
-
+                logger.debug("Applying cross-modal attention")
                # Multi-modal cross attention
                encoded_features = self.cross_attention(encoded_features, mask)
            else:
-
+                logger.debug("Single modality: skipping cross-attention")
                # Single modality: features pass through unchanged
                for key in encoded_features:
                    encoded_features[key] = encoded_features[key]  # Identity mapping

-        print(f"After cross-modal attention - encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+        logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

        # Feature fusion and output generation
+        logger.debug("Applying fusion module")
        fused_features = self.fusion_module(encoded_features, mask)
-        print(f"Fused features shape: {fused_features.shape}")
+        logger.debug(f"Fused features shape: {fused_features.shape}")

        # Ensure input to final_block matches hidden_dim
        # Ensure z ∈ ℝ^{B×H}, H: hidden dimension
        batch_size = fused_features.size(0)

        # Repeat the features to match the expected hidden dimension
        if fused_features.size(1) != self.hidden_dim:
+            logger.debug("Adjusting fused features dimension")
            fused_features = fused_features.repeat(1, self.hidden_dim // fused_features.size(1))

        # Precision projection if dimension mismatch persists
        # π_H: ℝ^k → ℝ^H
        if fused_features.size(1) != self.hidden_dim:
+            logger.debug("Creating precision projection")
            projection = nn.Linear(fused_features.size(1), self.hidden_dim).to(fused_features.device)
            fused_features = projection(fused_features)

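A shape sketch for this two-step fix-up, with an assumed hidden_dim:

# hidden_dim=128 (illustrative):
#   fused (B, 64) -> repeat(1, 128 // 64) -> (B, 128)    # matches, no projection needed
#   fused (B, 48) -> repeat(1, 128 // 48) -> (B, 96)     # still mismatched
#                 -> nn.Linear(96, 128) applied          # projection built inside forward
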
        # Final output generation
        # ψ: ℝ^H → ℝ^M, M: output features
        output = self.final_block(fused_features)
-
+        logger.debug(f"Final output shape: {output.shape}")
+
        return output


class DynamicResidualEncoder(DynamicFusionEncoder):
    """Dynamic fusion implementation with residual connectivity"""

    def __init__(self, *args, **kwargs):
+
+        logger.info("Initialising DynamicResidualEncoder")
        super().__init__(*args, **kwargs)

        # Enhanced projection with residual pathways
        # With residual transformation
        # φ_m: ℝ^H → ℝ^H
+        logger.debug("Creating residual feature projections")
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.LayerNorm(self.hidden_dim),
@@ -322,6 +363,9 @@ def forward(
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward implementation with residual pathways"""

+        logger.info("Starting DynamicResidualEncoder forward pass")
+        logger.debug(f"Input modalities: {list(inputs.keys())}")
+
        # Encoded features dictionary
@@ -333,35 +377,48 @@ def forward(
            if modality not in self.modality_encoders or x is None:
                continue

+            logger.debug(f"Processing {modality} with shape {x.shape}")
            encoded = self.modality_encoders[modality](x)
+            logger.debug(f"Encoded shape: {encoded.shape}")

            # Residual connection
            # x_m ⊕ R_m(x_m)
            projected = encoded + self.feature_projections[modality](encoded)
+            logger.debug(f"Projected shape with residual: {projected.shape}")
            encoded_features[modality] = projected

        if not encoded_features:
-            raise ValueError("No valid features after encoding")
+            error_msg = "No valid features after encoding"
+            logger.error(error_msg)
+            raise ValueError(error_msg)

        # Gating with residual pathways
        # g_m: x_m ⊕ g(x_m)
        if self.use_gating:
+            logger.debug("Applying gating with residual connections")
            gated_features = self.gating(encoded_features)
            for modality in encoded_features:
                gated_features[modality] = gated_features[modality] + encoded_features[modality]
            encoded_features = gated_features
-
+            logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
        # Attention with residual pathways
        # A_m: x_m ⊕ A(x_m)
        if self.use_cross_attention and len(encoded_features) > 1:
+            logger.debug("Applying cross-attention with residual connections")
            attended_features = self.cross_attention(encoded_features, mask)
            for modality in encoded_features:
                attended_features[modality] = attended_features[modality] + encoded_features[modality]
            encoded_features = attended_features
-
+            logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
        # Final fusion and output generation
+        logger.debug("Applying fusion module")
        fused_features = self.fusion_module(encoded_features, mask)
        fused_features = fused_features.repeat(1, self.sequence_length)
+        logger.debug(f"Fused features shape: {fused_features.shape}")
+
        output = self.final_block(fused_features)
-
+        logger.debug(f"Final output shape: {output.shape}")
+
        return output
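
A hypothetical end-to-end call for orientation; the encoder instance and tensor shapes below are assumptions, not taken from this commit:

# Hypothetical usage sketch: assumes `encoder` is an already-constructed
# DynamicFusionEncoder with sequence_length=12 and a "pv" modality whose
# per-timestep input has 10 sites; all sizes are illustrative.
inputs = {"pv": torch.randn(4, 12, 10)}   # (batch, sequence_length, num_sites)
output = encoder(inputs)                  # -> (batch, out_features)
logger.debug(f"output shape: {output.shape}")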