from typing import Dict, Optional, List, Union
import torch
from torch import nn
+import logging

from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating
from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention
from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2


+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger('dynamic_encoder')
+
+
# Attention head compatibility function
def get_compatible_heads(dim: int, target_heads: int) -> int:
    """Calculate largest compatible number of heads <= target_heads"""

+    logger.debug(f"Finding compatible heads for dim={dim}, target_heads={target_heads}")
+
    # Iterative reduction
    # Obtain maximum divisible number of heads
    # h ∈ ℕ : h ≤ target_heads ∧ dim mod h = 0
    for h in range(min(target_heads, dim), 0, -1):
        if dim % h == 0:
+            logger.debug(f"Selected compatible head count: {h}")
            return h
    return 1
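# Example behaviour (illustrative): get_compatible_heads(48, 10) steps down from 10
# and returns 8, the largest divisor of 48 not exceeding the target; for a prime
# dimension, e.g. get_compatible_heads(7, 4), only h = 1 divides, so 1 is returned.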

@@ -35,6 +43,7 @@ class PVEncoder(nn.Module):

    def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        super().__init__()
+        logger.info(f"Initialising PVEncoder with sequence_length={sequence_length}, num_sites={num_sites}")

        # Temporal and spatial configuration parameters
        # L: sequence length
@@ -46,6 +55,7 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        # Basic feature extraction network
        # φ: ℝ^M → ℝ^N
        # Linear Transformation → Layer Normalization → ReLU → Dropout
+        logger.debug("Creating encoder network")
        self.encoder = nn.Sequential(
            nn.Linear(num_sites, out_features),
            nn.LayerNorm(out_features),
@@ -57,13 +67,17 @@ def forward(self, x):

        # Sequential processing - maintain temporal order
        # x ∈ ℝ^{B×L×M} → out ∈ ℝ^{B×L×N}
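        # Note: Linear, LayerNorm and Dropout all act on the last dimension, so applying
        # self.encoder to the full (B, L, M) tensor would likely be equivalent and faster;
        # the explicit per-timestep loop is kept here for traceability.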
+        logger.debug(f"PVEncoder input shape: {x.shape}")
        batch_size = x.shape[0]
        out = []
        for t in range(self.sequence_length):
+            logger.debug(f"Processing timestep {t}")
            out.append(self.encoder(x[:, t]))

        # Reshape maintaining sequence dimension
-        return torch.stack(out, dim=1)
+        result = torch.stack(out, dim=1)
+        logger.debug(f"PVEncoder output shape: {result.shape}")
+        return result


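# Minimal usage sketch for PVEncoder (shapes follow the comments above; the concrete
# numbers are illustrative, not taken from any PVNet configuration):
#
#   encoder = PVEncoder(sequence_length=12, num_sites=100, out_features=64)
#   x = torch.randn(8, 12, 100)   # (batch, sequence_length, num_sites)
#   y = encoder(x)                # -> torch.Size([8, 12, 64])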
# Primary fusion encoder implementation
@@ -87,6 +101,8 @@ def __init__(
    ):
        """Dynamic fusion encoder initialisation"""

+        logger.info(f"Initialising DynamicFusionEncoder with sequence_length={sequence_length}, out_features={out_features}")
+
        super().__init__(
            sequence_length=sequence_length,
            image_size_pixels=image_size_pixels,
@@ -96,10 +112,12 @@ def __init__(

        # Dimension validation and compatibility
        # Adjust hidden dimension to be divisible by sequence length
-        # H = feature_dim × sequence_length
+        # H = feature_dim × sequence_length
+        logger.debug(f"Initial hidden_dim: {hidden_dim}")
        if hidden_dim % sequence_length != 0:
            feature_dim = ((hidden_dim + sequence_length - 1) // sequence_length)
            hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted hidden_dim to {hidden_dim} for sequence length compatibility")
        else:
            feature_dim = hidden_dim // sequence_length
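        # Worked example (illustrative): hidden_dim=100 with sequence_length=12 gives
        # feature_dim = ceil(100 / 12) = 9 and an adjusted hidden_dim of 9 * 12 = 108.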

@@ -108,15 +126,18 @@ def __init__(
        # h ∈ ℕ : h ≤ num_heads ∧ feature_dim mod h = 0
        attention_heads = cross_attention.get('num_heads', num_heads)
        attention_heads = get_compatible_heads(feature_dim, attention_heads)
-
+        logger.debug(f"Using {attention_heads} attention heads")
+
        # Dimension adjustment for attention mechanism
        # Ensure feature dimension is compatible with attention heads
        if feature_dim < attention_heads:
            feature_dim = attention_heads
            hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted dimensions - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")
        elif feature_dim % attention_heads != 0:
            feature_dim = ((feature_dim + attention_heads - 1) // attention_heads) * attention_heads
            hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted for attention compatibility - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")
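        # In practice attention_heads already divides feature_dim after the
        # get_compatible_heads call above, so these branches mainly act as a safeguard.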

        # Architecture dimensions
        self.feature_dim = feature_dim
@@ -129,6 +150,7 @@ def __init__(
        dynamic_fusion['num_heads'] = attention_heads

        # Modality specific encoder instantiation
+        logger.debug("Creating modality encoders")
        self.modality_encoders = nn.ModuleDict()
        for modality, config in modality_encoders.items():
            config = config.copy()
@@ -161,6 +183,7 @@ def __init__(
            )

        # Feature transformation layers
+        logger.debug("Creating feature projections")
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.LayerNorm(feature_dim),
@@ -174,6 +197,7 @@ def __init__(
        # Modality gating mechanism
        self.use_gating = use_gating
        if use_gating:
+            logger.debug("Initialising gating mechanism")
            gating_config = modality_gating.copy()
            gating_config.update({
                'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -184,6 +208,7 @@ def __init__(
        # Cross modal attention mechanism
        self.use_cross_attention = use_cross_attention and len(modality_channels) > 1
        if self.use_cross_attention:
+            logger.debug("Initialising cross attention")
            attention_config = cross_attention.copy()
            attention_config.update({
                'embed_dim': feature_dim,
@@ -194,6 +219,7 @@ def __init__(
            self.cross_attention = CrossModalAttention(**attention_config)

        # Dynamic fusion implementation
+        logger.debug("Initialising fusion module")
        fusion_config = dynamic_fusion.copy()
        fusion_config.update({
            'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -204,6 +230,7 @@ def __init__(
        self.fusion_module = DynamicFusionModule(**fusion_config)

        # Output network definition
+        logger.debug("Creating final output block")
        self.final_block = nn.Sequential(
            nn.Linear(hidden_dim, fc_features),
            nn.LayerNorm(fc_features),
@@ -219,6 +246,9 @@ def forward(
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

+        logger.info("Starting DynamicFusionEncoder forward pass")
+        logger.debug(f"Input modalities: {list(inputs.keys())}")
+
        # Encoded features dictionary
        # M ∈ {x_m | m ∈ Modalities}
        encoded_features = {}
@@ -230,80 +260,91 @@ def forward(
                continue

            # Feature extraction and projection
+            logger.debug(f"Encoding {modality} input of shape {x.shape}")
            encoded = self.modality_encoders[modality](x)
-            print(f"Encoded {modality} shape: {encoded.shape}")
+            logger.debug(f"Encoded {modality} shape: {encoded.shape}")

            # Temporal projection across sequence
            # π: ℝ^{B×L×D} → ℝ^{B×L×D}
            projected = torch.stack([
                self.feature_projections[modality](encoded[:, t])
                for t in range(self.sequence_length)
            ], dim=1)
-            print(f"Projected {modality} shape: {projected.shape}")
+            logger.debug(f"Projected {modality} shape: {projected.shape}")

            encoded_features[modality] = projected

        # Validation of encoded feature space
        # |M| > 0
        if not encoded_features:
-            raise ValueError("No valid features after encoding")
+            error_msg = "No valid features after encoding"
+            logger.error(error_msg)
+            raise ValueError(error_msg)

        # Apply modality interaction mechanisms
        if self.use_gating:

            # g: M → M̂
            # Adaptive feature transformation with learned gates
+            logger.debug("Applying modality gating")
            encoded_features = self.gating(encoded_features)
-            print(f"After gating, encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+            logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

        # Cross-modal attention mechanism
        if self.use_cross_attention:
            if len(encoded_features) > 1:
-
+                logger.debug("Applying cross-modal attention")
                # Multi-modal cross attention
                encoded_features = self.cross_attention(encoded_features, mask)
            else:
-
+                logger.debug("Single modality: skipping cross-attention")
                # Single modality: no cross-modal partner, so features pass through
                # unchanged (a SelfAttention block could be substituted here)
                for key in encoded_features:
                    encoded_features[key] = encoded_features[key]  # Identity mapping

-            print(f"After cross-modal attention - encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+            logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

        # Feature fusion and output generation
+        logger.debug("Applying fusion module")
        fused_features = self.fusion_module(encoded_features, mask)
-        print(f"Fused features shape: {fused_features.shape}")
+        logger.debug(f"Fused features shape: {fused_features.shape}")

        # Ensure input to final_block matches hidden_dim
        # Ensure z ∈ ℝ^{B×H}, H: hidden dimension
        batch_size = fused_features.size(0)

        # Repeat the features to match the expected hidden dimension
        if fused_features.size(1) != self.hidden_dim:
+            logger.debug("Adjusting fused features dimension")
            fused_features = fused_features.repeat(1, self.hidden_dim // fused_features.size(1))
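            # repeat() only reaches hidden_dim exactly when hidden_dim is an integer
            # multiple of the fused width; any remainder is handled by the projection below.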

        # Precision projection if dimension mismatch persists
        # π_H: ℝ^k → ℝ^H
        if fused_features.size(1) != self.hidden_dim:
+            logger.debug("Creating precision projection")
            projection = nn.Linear(fused_features.size(1), self.hidden_dim).to(fused_features.device)
            fused_features = projection(fused_features)
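            # Caveat: this Linear is created inside forward(), so its weights are
            # re-initialised on every call and are never registered with or trained as
            # part of the model.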

        # Final output generation
        # ψ: ℝ^H → ℝ^M, M: output features
        output = self.final_block(fused_features)
-
+        logger.debug(f"Final output shape: {output.shape}")
+
        return output


class DynamicResidualEncoder(DynamicFusionEncoder):
    """Dynamic fusion implementation with residual connectivity"""

    def __init__(self, *args, **kwargs):
+
+        logger.info("Initialising DynamicResidualEncoder")
        super().__init__(*args, **kwargs)

        # Enhanced projection with residual pathways
        # With residual transformation
        # φ_m: ℝ^H → ℝ^H
+        logger.debug("Creating residual feature projections")
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.LayerNorm(self.hidden_dim),
@@ -322,6 +363,9 @@ def forward(
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        """Forward implementation with residual pathways"""

+        logger.info("Starting DynamicResidualEncoder forward pass")
+        logger.debug(f"Input modalities: {list(inputs.keys())}")
+
        # Encoded features dictionary
@@ -333,35 +377,48 @@ def forward(
            if modality not in self.modality_encoders or x is None:
                continue

+            logger.debug(f"Processing {modality} with shape {x.shape}")
            encoded = self.modality_encoders[modality](x)
+            logger.debug(f"Encoded shape: {encoded.shape}")

            # Residual connection
            # x_m ⊕ R_m(x_m)
            projected = encoded + self.feature_projections[modality](encoded)
+            logger.debug(f"Projected shape with residual: {projected.shape}")
            encoded_features[modality] = projected

        if not encoded_features:
-            raise ValueError("No valid features after encoding")
+            error_msg = "No valid features after encoding"
+            logger.error(error_msg)
+            raise ValueError(error_msg)

        # Gating with residual pathways
        # g_m: x_m ⊕ g(x_m)
        if self.use_gating:
+            logger.debug("Applying gating with residual connections")
            gated_features = self.gating(encoded_features)
            for modality in encoded_features:
                gated_features[modality] = gated_features[modality] + encoded_features[modality]
            encoded_features = gated_features
-
+            logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
        # Attention with residual pathways
        # A_m: x_m ⊕ A(x_m)
        if self.use_cross_attention and len(encoded_features) > 1:
+            logger.debug("Applying cross-attention with residual connections")
            attended_features = self.cross_attention(encoded_features, mask)
            for modality in encoded_features:
                attended_features[modality] = attended_features[modality] + encoded_features[modality]
            encoded_features = attended_features
-
+            logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
        # Final fusion and output generation
+        logger.debug("Applying fusion module")
        fused_features = self.fusion_module(encoded_features, mask)
        fused_features = fused_features.repeat(1, self.sequence_length)
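        # Assumes the fusion module returns feature_dim-wide vectors, so repeating them
        # sequence_length times matches the hidden_dim expected by final_block.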
+        logger.debug(f"Fused features shape: {fused_features.shape}")
+
        output = self.final_block(fused_features)
-
+        logger.debug(f"Final output shape: {output.shape}")
+
        return output