Commit e644d6e

Dynamic encoder update - debug
1 parent 8cd6165 commit e644d6e

File tree

1 file changed: +73 -16 lines changed


pvnet/models/multimodal/encoders/dynamic_encoder.py

Lines changed: 73 additions & 16 deletions
@@ -9,22 +9,30 @@
 from typing import Dict, Optional, List, Union
 import torch
 from torch import nn
+import logging

 from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
 from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating
 from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention
 from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2


+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger('dynamic_encoder')
+
+
 # Attention head compatibility function
 def get_compatible_heads(dim: int, target_heads: int) -> int:
     """ Calculate largest compatible number of heads <= target_heads """

+    logger.debug(f"Finding compatible heads for dim={dim}, target_heads={target_heads}")
+
     # Iterative reduction
     # Obtain maximum divisible number of heads
     # h ∈ ℕ : h ≤ target_heads ∧ dim mod h = 0
     for h in range(min(target_heads, dim), 0, -1):
         if dim % h == 0:
+            logger.debug(f"Selected compatible head count: {h}")
             return h
     return 1
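
For reference, a minimal sketch of how get_compatible_heads behaves; the dimension values below are hypothetical and chosen only for illustration:

    get_compatible_heads(48, 7)   # -> 6: the largest divisor of 48 that is <= 7
    get_compatible_heads(13, 4)   # -> 1: 13 is prime, so only a single head divides it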

@@ -35,6 +43,7 @@ class PVEncoder(nn.Module):

     def __init__(self, sequence_length: int, num_sites: int, out_features: int):
         super().__init__()
+        logger.info(f"Initialising PVEncoder with sequence_length={sequence_length}, num_sites={num_sites}")

         # Temporal and spatial configuration parameters
         # L: sequence length
@@ -46,6 +55,7 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):
         # Basic feature extraction network
         # φ: ℝ^M → ℝ^N
         # Linear Transformation → Layer Normalization → ReLU → Dropout
+        logger.debug("Creating encoder network")
         self.encoder = nn.Sequential(
             nn.Linear(num_sites, out_features),
             nn.LayerNorm(out_features),
@@ -57,13 +67,17 @@ def forward(self, x):

         # Sequential processing - maintain temporal order
         # x ∈ ℝ^{B×L×M} → out ∈ ℝ^{B×L×N}
+        logger.debug(f"PVEncoder input shape: {x.shape}")
         batch_size = x.shape[0]
         out = []
         for t in range(self.sequence_length):
+            logger.debug(f"Processing timestep {t}")
             out.append(self.encoder(x[:, t]))

         # Reshape maintaining sequence dimension
-        return torch.stack(out, dim=1)
+        result = torch.stack(out, dim=1)
+        logger.debug(f"PVEncoder output shape: {result.shape}")
+        return result


 # Primary fusion encoder implementation
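
A quick shape check of PVEncoder as defined above; the batch size, site count and feature width are hypothetical:

    enc = PVEncoder(sequence_length=12, num_sites=10, out_features=32)
    out = enc(torch.randn(4, 12, 10))   # each timestep (4, 10) -> (4, 32), stacked to (4, 12, 32)
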
@@ -87,6 +101,8 @@ def __init__(
     ):
         """ Dynamic fusion encoder initialisation """

+        logger.info(f"Initialising DynamicFusionEncoder with sequence_length={sequence_length}, out_features={out_features}")
+
         super().__init__(
             sequence_length=sequence_length,
             image_size_pixels=image_size_pixels,
@@ -96,10 +112,12 @@ def __init__(

         # Dimension validation and compatibility
         # Adjust hidden dimension to be divisible by sequence length
-        # H = feature_dim × sequence_length
+        # H = feature_dim × sequence_length
+        logger.debug(f"Initial hidden_dim: {hidden_dim}")
         if hidden_dim % sequence_length != 0:
             feature_dim = ((hidden_dim + sequence_length - 1) // sequence_length)
             hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted hidden_dim to {hidden_dim} for sequence length compatibility")
         else:
             feature_dim = hidden_dim // sequence_length
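
To make the rounding above concrete, a worked example with hypothetical values:

    # hidden_dim = 100, sequence_length = 12  (hypothetical inputs)
    # feature_dim = (100 + 12 - 1) // 12 = 9
    # hidden_dim  = 9 * 12 = 108, now divisible by the sequence length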

@@ -108,15 +126,18 @@ def __init__(
         # h ∈ ℕ : h ≤ num_heads ∧ feature_dim mod h = 0
         attention_heads = cross_attention.get('num_heads', num_heads)
         attention_heads = get_compatible_heads(feature_dim, attention_heads)
-
+        logger.debug(f"Using {attention_heads} attention heads")
+
         # Dimension adjustment for attention mechanism
         # Ensure feature dimension is compatible with attention heads
         if feature_dim < attention_heads:
             feature_dim = attention_heads
             hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted dimensions - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")
         elif feature_dim % attention_heads != 0:
             feature_dim = ((feature_dim + attention_heads - 1) // attention_heads) * attention_heads
             hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted for attention compatibility - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")

         # Architecture dimensions
         self.feature_dim = feature_dim
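
Continuing the hypothetical numbers above (feature_dim = 9, requested num_heads = 8):

    # get_compatible_heads(9, 8) -> 3, the largest divisor of 9 that does not exceed 8
    # 9 >= 3 and 9 % 3 == 0, so neither adjustment branch changes the dimensions here
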
@@ -129,6 +150,7 @@ def __init__(
         dynamic_fusion['num_heads'] = attention_heads

         # Modality specific encoder instantiation
+        logger.debug("Creating modality encoders")
         self.modality_encoders = nn.ModuleDict()
         for modality, config in modality_encoders.items():
             config = config.copy()
@@ -161,6 +183,7 @@ def __init__(
             )

         # Feature transformation layers
+        logger.debug("Creating feature projections")
         self.feature_projections = nn.ModuleDict({
             modality: nn.Sequential(
                 nn.LayerNorm(feature_dim),
@@ -174,6 +197,7 @@ def __init__(
         # Modality gating mechanism
         self.use_gating = use_gating
         if use_gating:
+            logger.debug("Initialising gating mechanism")
             gating_config = modality_gating.copy()
             gating_config.update({
                 'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -184,6 +208,7 @@ def __init__(
         # Cross modal attention mechanism
         self.use_cross_attention = use_cross_attention and len(modality_channels) > 1
         if self.use_cross_attention:
+            logger.debug("Initialising cross attention")
             attention_config = cross_attention.copy()
             attention_config.update({
                 'embed_dim': feature_dim,
@@ -194,6 +219,7 @@ def __init__(
             self.cross_attention = CrossModalAttention(**attention_config)

         # Dynamic fusion implementation
+        logger.debug("Initialising fusion module")
         fusion_config = dynamic_fusion.copy()
         fusion_config.update({
             'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -204,6 +230,7 @@ def __init__(
         self.fusion_module = DynamicFusionModule(**fusion_config)

         # Output network definition
+        logger.debug("Creating final output block")
         self.final_block = nn.Sequential(
             nn.Linear(hidden_dim, fc_features),
             nn.LayerNorm(fc_features),
@@ -219,6 +246,9 @@ def forward(
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:

+        logger.info("Starting DynamicFusionEncoder forward pass")
+        logger.debug(f"Input modalities: {list(inputs.keys())}")
+
         # Encoded features dictionary
         # M ∈ {x_m | m ∈ Modalities}
         encoded_features = {}
@@ -230,80 +260,91 @@ def forward(
                 continue

             # Feature extraction and projection
+            logger.debug(f"Encoding {modality} input of shape {x.shape}")
             encoded = self.modality_encoders[modality](x)
-            print(f"Encoded {modality} shape: {encoded.shape}")
+            logger.debug(f"Encoded {modality} shape: {encoded.shape}")

             # Temporal projection across sequence
             # π: ℝ^{B×L×D} → ℝ^{B×L×D}
             projected = torch.stack([
                 self.feature_projections[modality](encoded[:, t])
                 for t in range(self.sequence_length)
             ], dim=1)
-            print(f"Projected {modality} shape: {projected.shape}")
+            logger.debug(f"Projected {modality} shape: {projected.shape}")

             encoded_features[modality] = projected

         # Validation of encoded feature space
         # |M| > 0
         if not encoded_features:
-            raise ValueError("No valid features after encoding")
+            error_msg = "No valid features after encoding"
+            logger.error(error_msg)
+            raise ValueError(error_msg)

         # Apply modality interaction mechanisms
         if self.use_gating:

             # g: M → M̂
             # Adaptive feature transformation with learned gates
+            logger.debug("Applying modality gating")
             encoded_features = self.gating(encoded_features)
-            print(f"After gating, encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+            logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

         # Cross-modal attention mechanism
         if self.use_cross_attention:
             if len(encoded_features) > 1:
-
+                logger.debug("Applying cross-modal attention")
                 # Multi-modal cross attention
                 encoded_features = self.cross_attention(encoded_features, mask)
             else:
-
+                logger.debug("Single modality: skipping cross-attention")
                 # For single modality, apply self-attention instead
                 for key in encoded_features:
                     encoded_features[key] = encoded_features[key] # Identity mapping

-        print(f"After cross-modal attention - encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+        logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

         # Feature fusion and output generation
+        logger.debug("Applying fusion module")
         fused_features = self.fusion_module(encoded_features, mask)
-        print(f"Fused features shape: {fused_features.shape}")
+        logger.debug(f"Fused features shape: {fused_features.shape}")

         # Ensure input to final_block matches hidden_dim
         # Ensure z ∈ ℝ^{B×H}, H: hidden dimension
         batch_size = fused_features.size(0)

         # Repeat the features to match the expected hidden dimension
         if fused_features.size(1) != self.hidden_dim:
+            logger.debug("Adjusting fused features dimension")
             fused_features = fused_features.repeat(1, self.hidden_dim // fused_features.size(1))

         # Precision projection if dimension mismatch persists
         # π_H: ℝ^k → ℝ^H
         if fused_features.size(1) != self.hidden_dim:
+            logger.debug("Creating precision projection")
             projection = nn.Linear(fused_features.size(1), self.hidden_dim).to(fused_features.device)
             fused_features = projection(fused_features)

         # Final output generation
         # ψ: ℝ^H → ℝ^M, M: output features
         output = self.final_block(fused_features)
-
+        logger.debug(f"Final output shape: {output.shape}")
+
         return output


 class DynamicResidualEncoder(DynamicFusionEncoder):
     """ Dynamic fusion implementation with residual connectivity """

     def __init__(self, *args, **kwargs):
+
+        logger.info("Initialising DynamicResidualEncoder")
         super().__init__(*args, **kwargs)

         # Enhanced projection with residual pathways
         # With residual transformation
         # φ_m: ℝ^H → ℝ^H
+        logger.debug("Creating residual feature projections")
         self.feature_projections = nn.ModuleDict({
             modality: nn.Sequential(
                 nn.LayerNorm(self.hidden_dim),
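
A sketch of the dimension alignment in the forward pass above, assuming the fusion module returns a (batch, feature_dim) tensor and reusing the hypothetical sizes feature_dim = 9, sequence_length = 12:

    # fused_features: (B, 9); self.hidden_dim = 9 * 12 = 108
    # repeat(1, 108 // 9) -> (B, 108), matching nn.Linear(hidden_dim, fc_features) in final_block
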
@@ -322,6 +363,9 @@ def forward(
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:

+        logger.info("Starting DynamicResidualEncoder forward pass")
+        logger.debug(f"Input modalities: {list(inputs.keys())}")
+
         """ Forward implementation with residual pathways """

         # Encoded features dictionary
@@ -333,35 +377,48 @@
             if modality not in self.modality_encoders or x is None:
                 continue

+            logger.debug(f"Processing {modality} with shape {x.shape}")
             encoded = self.modality_encoders[modality](x)
+            logger.debug(f"Encoded shape: {encoded.shape}")

             # Residual connection
             # x_m ⊕ R_m(x_m)
             projected = encoded + self.feature_projections[modality](encoded)
+            logger.debug(f"Projected shape with residual: {projected.shape}")
             encoded_features[modality] = projected

         if not encoded_features:
-            raise ValueError("No valid features after encoding")
+            error_msg = "No valid features after encoding"
+            logger.error(error_msg)
+            raise ValueError(error_msg)

         # Gating with residual pathways
         # g_m: x_m ⊕ g(x_m)
         if self.use_gating:
+            logger.debug("Applying gating with residual connections")
             gated_features = self.gating(encoded_features)
             for modality in encoded_features:
                 gated_features[modality] = gated_features[modality] + encoded_features[modality]
             encoded_features = gated_features
-
+            logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
         # Attention with residual pathways
         # A_m: x_m ⊕ A(x_m)
         if self.use_cross_attention and len(encoded_features) > 1:
+            logger.debug("Applying cross-attention with residual connections")
             attended_features = self.cross_attention(encoded_features, mask)
             for modality in encoded_features:
                 attended_features[modality] = attended_features[modality] + encoded_features[modality]
             encoded_features = attended_features
-
+            logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
         # Final fusion and output generation
+        logger.debug("Applying fusion module")
         fused_features = self.fusion_module(encoded_features, mask)
         fused_features = fused_features.repeat(1, self.sequence_length)
+        logger.debug(f"Fused features shape: {fused_features.shape}")
+
         output = self.final_block(fused_features)
-
+        logger.debug(f"Final output shape: {output.shape}")
+
         return output
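
Note: the commit calls logging.basicConfig(level=logging.DEBUG, ...) at import time, so the per-timestep debug output is enabled by default (unless the root logger was already configured). A minimal way for a downstream script to quieten this module, using only the standard library:

    import logging
    logging.getLogger('dynamic_encoder').setLevel(logging.WARNING)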
