diff --git a/pvnet/models/multimodal/encoders/dynamic_encoder.py b/pvnet/models/multimodal/encoders/dynamic_encoder.py
index 4111c019..f35c19f8 100644
--- a/pvnet/models/multimodal/encoders/dynamic_encoder.py
+++ b/pvnet/models/multimodal/encoders/dynamic_encoder.py
@@ -9,6 +9,7 @@
 from typing import Dict, Optional, List, Union
 
 import torch
 from torch import nn
+import logging
 
 from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
 from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating
@@ -16,15 +17,22 @@
 from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2
 
 
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger('dynamic_encoder')
+
+
 # Attention head compatibility function
 def get_compatible_heads(dim: int, target_heads: int) -> int:
     """ Calculate largest compatible number of heads <= target_heads """
+    logger.debug(f"Finding compatible heads for dim={dim}, target_heads={target_heads}")
+
     # Iterative reduction
     # Obtain maximum divisible number of heads
     # h ∈ ℕ : h ≤ target_heads ∧ dim mod h = 0
     for h in range(min(target_heads, dim), 0, -1):
         if dim % h == 0:
+            logger.debug(f"Selected compatible head count: {h}")
             return h
     return 1
 
@@ -35,6 +43,7 @@ class PVEncoder(nn.Module):
 
     def __init__(self, sequence_length: int, num_sites: int, out_features: int):
         super().__init__()
+        logger.info(f"Initialising PVEncoder with sequence_length={sequence_length}, num_sites={num_sites}")
 
         # Temporal and spatial configuration parameters
         # L: sequence length
@@ -46,6 +55,7 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):
         # Basic feature extraction network
         # φ: ℝ^M → ℝ^N
         # Linear Transformation → Layer Normalization → ReLU → Dropout
+        logger.debug("Creating encoder network")
         self.encoder = nn.Sequential(
             nn.Linear(num_sites, out_features),
             nn.LayerNorm(out_features),
@@ -57,13 +67,17 @@ def forward(self, x):
         # Sequential processing - maintain temporal order
         # x ∈ ℝ^{B×L×M} → out ∈ ℝ^{B×L×N}
+        logger.debug(f"PVEncoder input shape: {x.shape}")
         batch_size = x.shape[0]
         out = []
         for t in range(self.sequence_length):
+            logger.debug(f"Processing timestep {t}")
             out.append(self.encoder(x[:, t]))
 
         # Reshape maintaining sequence dimension
-        return torch.stack(out, dim=1)
+        result = torch.stack(out, dim=1)
+        logger.debug(f"PVEncoder output shape: {result.shape}")
+        return result
 
 
 # Primary fusion encoder implementation
@@ -87,6 +101,8 @@ def __init__(
     ):
         """ Dynamic fusion encoder initialisation """
+        logger.info(f"Initialising DynamicFusionEncoder with sequence_length={sequence_length}, out_features={out_features}")
+
         super().__init__(
             sequence_length=sequence_length,
             image_size_pixels=image_size_pixels,
@@ -96,10 +112,12 @@ def __init__(
 
         # Dimension validation and compatibility
        # Adjust hidden dimension to be divisible by sequence length
-        # H = feature_dim × sequence_length
+        # H = feature_dim × sequence_length
+        logger.debug(f"Initial hidden_dim: {hidden_dim}")
         if hidden_dim % sequence_length != 0:
             feature_dim = ((hidden_dim + sequence_length - 1) // sequence_length)
             hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted hidden_dim to {hidden_dim} for sequence length compatibility")
         else:
             feature_dim = hidden_dim // sequence_length
 
@@ -108,15 +126,18 @@ def __init__(
         # h ∈ ℕ : h ≤ num_heads ∧ feature_dim mod h = 0
         attention_heads = cross_attention.get('num_heads', num_heads)
         attention_heads = get_compatible_heads(feature_dim, attention_heads)
-
+        logger.debug(f"Using {attention_heads} attention heads")
+
         # Dimension adjustment for attention mechanism
         # Ensure feature dimension is compatible with attention heads
         if feature_dim < attention_heads:
             feature_dim = attention_heads
             hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted dimensions - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")
         elif feature_dim % attention_heads != 0:
             feature_dim = ((feature_dim + attention_heads - 1) // attention_heads) * attention_heads
             hidden_dim = feature_dim * sequence_length
+            logger.debug(f"Adjusted for attention compatibility - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")
 
         # Architecture dimensions
         self.feature_dim = feature_dim
@@ -129,6 +150,7 @@ def __init__(
         dynamic_fusion['num_heads'] = attention_heads
 
         # Modality specific encoder instantiation
+        logger.debug("Creating modality encoders")
         self.modality_encoders = nn.ModuleDict()
         for modality, config in modality_encoders.items():
             config = config.copy()
@@ -161,6 +183,7 @@ def __init__(
             )
 
         # Feature transformation layers
+        logger.debug("Creating feature projections")
         self.feature_projections = nn.ModuleDict({
             modality: nn.Sequential(
                 nn.LayerNorm(feature_dim),
@@ -174,6 +197,7 @@ def __init__(
         # Modality gating mechanism
         self.use_gating = use_gating
         if use_gating:
+            logger.debug("Initialising gating mechanism")
             gating_config = modality_gating.copy()
             gating_config.update({
                 'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -184,6 +208,7 @@ def __init__(
         # Cross modal attention mechanism
         self.use_cross_attention = use_cross_attention and len(modality_channels) > 1
         if self.use_cross_attention:
+            logger.debug("Initialising cross attention")
             attention_config = cross_attention.copy()
             attention_config.update({
                 'embed_dim': feature_dim,
@@ -194,6 +219,7 @@ def __init__(
             self.cross_attention = CrossModalAttention(**attention_config)
 
         # Dynamic fusion implementation
+        logger.debug("Initialising fusion module")
         fusion_config = dynamic_fusion.copy()
         fusion_config.update({
             'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -204,6 +230,7 @@ def __init__(
         self.fusion_module = DynamicFusionModule(**fusion_config)
 
         # Output network definition
+        logger.debug("Creating final output block")
         self.final_block = nn.Sequential(
             nn.Linear(hidden_dim, fc_features),
             nn.LayerNorm(fc_features),
@@ -219,6 +246,9 @@ def forward(
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
 
+        logger.info("Starting DynamicFusionEncoder forward pass")
+        logger.debug(f"Input modalities: {list(inputs.keys())}")
+
         # Encoded features dictionary
         # M ∈ {x_m | m ∈ Modalities}
         encoded_features = {}
@@ -230,8 +260,9 @@ def forward(
                 continue
 
             # Feature extraction and projection
+            logger.debug(f"Encoding {modality} input of shape {x.shape}")
             encoded = self.modality_encoders[modality](x)
-            print(f"Encoded {modality} shape: {encoded.shape}")
+            logger.debug(f"Encoded {modality} shape: {encoded.shape}")
 
             # Temporal projection across sequence
             # π: ℝ^{B×L×D} → ℝ^{B×L×D}
@@ -239,40 +270,44 @@ def forward(
                 self.feature_projections[modality](encoded[:, t])
                 for t in range(self.sequence_length)
             ], dim=1)
-            print(f"Projected {modality} shape: {projected.shape}")
+            logger.debug(f"Projected {modality} shape: {projected.shape}")
             encoded_features[modality] = projected
 
         # Validation of encoded feature space
         # |M| > 0
         if not encoded_features:
-            raise ValueError("No valid features after encoding")
+            error_msg = "No valid features after encoding"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
 
         # Apply modality interaction mechanisms
         if self.use_gating:
             # g: M → M̂
             # Adaptive feature transformation with learned gates
+            logger.debug("Applying modality gating")
             encoded_features = self.gating(encoded_features)
-            print(f"After gating, encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+            logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
 
         # Cross-modal attention mechanism
         if self.use_cross_attention:
             if len(encoded_features) > 1:
-
+                logger.debug("Applying cross-modal attention")
                 # Multi-modal cross attention
                 encoded_features = self.cross_attention(encoded_features, mask)
             else:
-
+                logger.debug("Single modality: skipping cross-attention")
                 # For single modality, apply self-attention instead
                 for key in encoded_features:
                     encoded_features[key] = encoded_features[key]  # Identity mapping
-            print(f"After cross-modal attention - encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+            logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
 
         # Feature fusion and output generation
+        logger.debug("Applying fusion module")
         fused_features = self.fusion_module(encoded_features, mask)
-        print(f"Fused features shape: {fused_features.shape}")
+        logger.debug(f"Fused features shape: {fused_features.shape}")
 
         # Ensure input to final_block matches hidden_dim
         # Ensure z ∈ ℝ^{B×H}, H: hidden dimension
@@ -280,18 +315,21 @@
 
         # Repeat the features to match the expected hidden dimension
         if fused_features.size(1) != self.hidden_dim:
+            logger.debug("Adjusting fused features dimension")
             fused_features = fused_features.repeat(1, self.hidden_dim // fused_features.size(1))
 
         # Precision projection if dimension mismatch persists
         # π_H: ℝ^k → ℝ^H
         if fused_features.size(1) != self.hidden_dim:
+            logger.debug("Creating precision projection")
             projection = nn.Linear(fused_features.size(1), self.hidden_dim).to(fused_features.device)
             fused_features = projection(fused_features)
 
         # Final output generation
         # ψ: ℝ^H → ℝ^M, M: output features
         output = self.final_block(fused_features)
-
+        logger.debug(f"Final output shape: {output.shape}")
+
         return output
 
 
@@ -299,11 +337,14 @@ class DynamicResidualEncoder(DynamicFusionEncoder):
     """ Dynamic fusion implementation with residual connectivity """
 
     def __init__(self, *args, **kwargs):
+
+        logger.info("Initialising DynamicResidualEncoder")
         super().__init__(*args, **kwargs)
 
         # Enhanced projection with residual pathways
         # With residual transformation
         # φ_m: ℝ^H → ℝ^H
+        logger.debug("Creating residual feature projections")
         self.feature_projections = nn.ModuleDict({
             modality: nn.Sequential(
                 nn.LayerNorm(self.hidden_dim),
@@ -322,6 +363,9 @@ def forward(
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         """ Forward implementation with residual pathways """
+
+        logger.info("Starting DynamicResidualEncoder forward pass")
+        logger.debug(f"Input modalities: {list(inputs.keys())}")
 
         # Encoded features dictionary
@@ -333,35 +377,48 @@ def forward(
             if modality not in self.modality_encoders or x is None:
                 continue
 
+            logger.debug(f"Processing {modality} with shape {x.shape}")
             encoded = self.modality_encoders[modality](x)
+            logger.debug(f"Encoded shape: {encoded.shape}")
 
             # Residual connection
             # x_m ⊕ R_m(x_m)
             projected = encoded + self.feature_projections[modality](encoded)
+            logger.debug(f"Projected shape with residual: {projected.shape}")
             encoded_features[modality] = projected
 
         if not encoded_features:
-            raise ValueError("No valid features after encoding")
+            error_msg = "No valid features after encoding"
+            logger.error(error_msg)
+            raise ValueError(error_msg)
 
         # Gating with residual pathways
         # g_m: x_m ⊕ g(x_m)
         if self.use_gating:
+            logger.debug("Applying gating with residual connections")
             gated_features = self.gating(encoded_features)
             for modality in encoded_features:
                 gated_features[modality] = gated_features[modality] + encoded_features[modality]
             encoded_features = gated_features
-
+            logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
         # Attention with residual pathways
         # A_m: x_m ⊕ A(x_m)
         if self.use_cross_attention and len(encoded_features) > 1:
+            logger.debug("Applying cross-attention with residual connections")
            attended_features = self.cross_attention(encoded_features, mask)
             for modality in encoded_features:
                 attended_features[modality] = attended_features[modality] + encoded_features[modality]
             encoded_features = attended_features
-
+            logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
         # Final fusion and output generation
+        logger.debug("Applying fusion module")
         fused_features = self.fusion_module(encoded_features, mask)
         fused_features = fused_features.repeat(1, self.sequence_length)
+        logger.debug(f"Fused features shape: {fused_features.shape}")
+
         output = self.final_block(fused_features)
-
+        logger.debug(f"Final output shape: {output.shape}")
+
         return output
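Not part of the patch: a minimal standalone sketch, for reviewers, of the dimension-rounding arithmetic that the constructor now logs. `adjust_dims` is a hypothetical helper that mirrors the diff's logic under two simplifying assumptions: it always rounds `hidden_dim` up (equivalent when it is already divisible), and it takes the target head count directly instead of reading it from the `cross_attention` config.

```python
import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("dynamic_encoder_sketch")


def get_compatible_heads(dim: int, target_heads: int) -> int:
    """Largest h <= target_heads with dim % h == 0, falling back to 1."""
    for h in range(min(target_heads, dim), 0, -1):
        if dim % h == 0:
            return h
    return 1


def adjust_dims(hidden_dim: int, sequence_length: int, num_heads: int):
    """Mirror the constructor's rounding; returns (feature_dim, hidden_dim, heads)."""
    # Round hidden_dim up to a multiple of sequence_length
    feature_dim = (hidden_dim + sequence_length - 1) // sequence_length
    hidden_dim = feature_dim * sequence_length
    # Pick a head count that divides feature_dim, then re-round if needed
    heads = get_compatible_heads(feature_dim, num_heads)
    if feature_dim < heads:
        feature_dim = heads
    elif feature_dim % heads != 0:
        feature_dim = ((feature_dim + heads - 1) // heads) * heads
    logger.debug(f"feature_dim={feature_dim}, hidden_dim={feature_dim * sequence_length}, heads={heads}")
    return feature_dim, feature_dim * sequence_length, heads


if __name__ == "__main__":
    # hidden_dim=100 is not divisible by sequence_length=12, so it is rounded up
    # to 108; feature_dim=9 supports at most 3 of the requested 8 heads.
    assert adjust_dims(100, 12, 8) == (9, 108, 3)
```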
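Note (also not part of the patch): because the module calls `logging.basicConfig(level=logging.DEBUG)` at import time, the new per-timestep and per-modality records are emitted on every forward pass. Assuming the logger name `'dynamic_encoder'` from the diff, downstream code can quieten them without touching the handlers:

```python
import logging

# Raise the threshold for this module's logger only; handlers installed by
# basicConfig() remain available to the rest of the application.
logging.getLogger("dynamic_encoder").setLevel(logging.WARNING)
```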