Dynamic encoder update - debug
felix-e-h-p committed Feb 14, 2025
1 parent 8cd6165 commit e644d6e
Showing 1 changed file with 73 additions and 16 deletions.
89 changes: 73 additions & 16 deletions pvnet/models/multimodal/encoders/dynamic_encoder.py
@@ -9,22 +9,30 @@
from typing import Dict, Optional, List, Union
import torch
from torch import nn
import logging

from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating
from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention
from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2


logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('dynamic_encoder')
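# Note (added commentary, not in the original commit): logging.basicConfig here
# configures the root logger at import time, so importing this module switches on
# DEBUG-level output process-wide unless the application has already configured
# its own logging handlers.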


# Attention head compatibility function
def get_compatible_heads(dim: int, target_heads: int) -> int:
""" Calculate largest compatible number of heads <= target_heads """

logger.debug(f"Finding compatible heads for dim={dim}, target_heads={target_heads}")

# Iterative reduction
# Obtain maximum divisible number of heads
# h ∈ ℕ : h ≤ target_heads ∧ dim mod h = 0
for h in range(min(target_heads, dim), 0, -1):
if dim % h == 0:
logger.debug(f"Selected compatible head count: {h}")
return h
return 1
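# Illustrative behaviour (example values, not from the original file):
# get_compatible_heads(64, 6) -> 4, since 4 is the largest h <= 6 dividing 64;
# get_compatible_heads(7, 4)  -> 1, because no h in {4, 3, 2} divides a prime dim.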

@@ -35,6 +43,7 @@ class PVEncoder(nn.Module):

def __init__(self, sequence_length: int, num_sites: int, out_features: int):
super().__init__()
logger.info(f"Initialising PVEncoder with sequence_length={sequence_length}, num_sites={num_sites}")

# Temporal and spatial configuration parameters
# L: sequence length
@@ -46,6 +55,7 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):
# Basic feature extraction network
# φ: ℝ^M → ℝ^N
# Linear Transformation → Layer Normalization → ReLU → Dropout
logger.debug("Creating encoder network")
self.encoder = nn.Sequential(
nn.Linear(num_sites, out_features),
nn.LayerNorm(out_features),
@@ -57,13 +67,17 @@ def forward(self, x):

# Sequential processing - maintain temporal order
# x ∈ ℝ^{B×L×M} → out ∈ ℝ^{B×L×N}
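# Shape sketch with assumed values (B=8, L=sequence_length=12, M=num_sites=10,
# N=out_features=64): x of shape (8, 12, 10) is encoded one timestep at a time
# and stacked back into an output of shape (8, 12, 64).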
logger.debug(f"PVEncoder input shape: {x.shape}")
batch_size = x.shape[0]
out = []
for t in range(self.sequence_length):
logger.debug(f"Processing timestep {t}")
out.append(self.encoder(x[:, t]))

# Reshape maintaining sequence dimension
result = torch.stack(out, dim=1)
logger.debug(f"PVEncoder output shape: {result.shape}")
return result


# Primary fusion encoder implementation
@@ -87,6 +101,8 @@ def __init__(
):
""" Dynamic fusion encoder initialisation """

logger.info(f"Initialising DynamicFusionEncoder with sequence_length={sequence_length}, out_features={out_features}")

super().__init__(
sequence_length=sequence_length,
image_size_pixels=image_size_pixels,
@@ -96,10 +112,12 @@ def __init__(

# Dimension validation and compatibility
# Adjust hidden dimension to be divisible by sequence length
# H = feature_dim × sequence_length
logger.debug(f"Initial hidden_dim: {hidden_dim}")
if hidden_dim % sequence_length != 0:
feature_dim = ((hidden_dim + sequence_length - 1) // sequence_length)
hidden_dim = feature_dim * sequence_length
logger.debug(f"Adjusted hidden_dim to {hidden_dim} for sequence length compatibility")
else:
feature_dim = hidden_dim // sequence_length
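# Worked example with assumed values: hidden_dim=100 and sequence_length=12 give
# feature_dim = ceil(100 / 12) = 9, so hidden_dim is rounded up to 9 * 12 = 108;
# a hidden_dim of 96 already divides evenly and yields feature_dim = 8.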

@@ -108,15 +126,18 @@ def __init__(
# h ∈ ℕ : h ≤ num_heads ∧ feature_dim mod h = 0
attention_heads = cross_attention.get('num_heads', num_heads)
attention_heads = get_compatible_heads(feature_dim, attention_heads)

logger.debug(f"Using {attention_heads} attention heads")

# Dimension adjustment for attention mechanism
# Ensure feature dimension is compatible with attention heads
if feature_dim < attention_heads:
feature_dim = attention_heads
hidden_dim = feature_dim * sequence_length
logger.debug(f"Adjusted dimensions - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")
elif feature_dim % attention_heads != 0:
feature_dim = ((feature_dim + attention_heads - 1) // attention_heads) * attention_heads
hidden_dim = feature_dim * sequence_length
logger.debug(f"Adjusted for attention compatibility - feature_dim: {feature_dim}, hidden_dim: {hidden_dim}")

# Architecture dimensions
self.feature_dim = feature_dim
@@ -129,6 +150,7 @@ def __init__(
dynamic_fusion['num_heads'] = attention_heads

# Modality specific encoder instantiation
logger.debug("Creating modality encoders")
self.modality_encoders = nn.ModuleDict()
for modality, config in modality_encoders.items():
config = config.copy()
@@ -161,6 +183,7 @@ def __init__(
)

# Feature transformation layers
logger.debug("Creating feature projections")
self.feature_projections = nn.ModuleDict({
modality: nn.Sequential(
nn.LayerNorm(feature_dim),
@@ -174,6 +197,7 @@ def __init__(
# Modality gating mechanism
self.use_gating = use_gating
if use_gating:
logger.debug("Initialising gating mechanism")
gating_config = modality_gating.copy()
gating_config.update({
'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -184,6 +208,7 @@ def __init__(
# Cross modal attention mechanism
self.use_cross_attention = use_cross_attention and len(modality_channels) > 1
if self.use_cross_attention:
logger.debug("Initialising cross attention")
attention_config = cross_attention.copy()
attention_config.update({
'embed_dim': feature_dim,
@@ -194,6 +219,7 @@ def __init__(
self.cross_attention = CrossModalAttention(**attention_config)

# Dynamic fusion implementation
logger.debug("Initialising fusion module")
fusion_config = dynamic_fusion.copy()
fusion_config.update({
'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
@@ -204,6 +230,7 @@ def __init__(
self.fusion_module = DynamicFusionModule(**fusion_config)

# Output network definition
logger.debug("Creating final output block")
self.final_block = nn.Sequential(
nn.Linear(hidden_dim, fc_features),
nn.LayerNorm(fc_features),
@@ -219,6 +246,9 @@ def forward(
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:

logger.info("Starting DynamicFusionEncoder forward pass")
logger.debug(f"Input modalities: {list(inputs.keys())}")

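# Rough shape walk-through under assumed settings (batch=4, sequence_length=12,
# feature_dim=8, hidden_dim=96, out_features=128): each modality is encoded and
# projected to (4, 12, 8); gating and cross-attention preserve that shape; the
# fusion module collapses it to a per-sample vector that is repeated or projected
# up to (4, 96); final_block then maps it to the requested out_features, e.g. (4, 128).
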
# Encoded features dictionary
# M ∈ {x_m | m ∈ Modalities}
encoded_features = {}
@@ -230,80 +260,91 @@ def forward(
continue

# Feature extraction and projection
logger.debug(f"Encoding {modality} input of shape {x.shape}")
encoded = self.modality_encoders[modality](x)
print(f"Encoded {modality} shape: {encoded.shape}")
logger.debug(f"Encoded {modality} shape: {encoded.shape}")

# Temporal projection across sequence
# π: ℝ^{B×L×D} → ℝ^{B×L×D}
projected = torch.stack([
self.feature_projections[modality](encoded[:, t])
for t in range(self.sequence_length)
], dim=1)
print(f"Projected {modality} shape: {projected.shape}")
logger.debug(f"Projected {modality} shape: {projected.shape}")

encoded_features[modality] = projected

# Validation of encoded feature space
# |M| > 0
if not encoded_features:
raise ValueError("No valid features after encoding")
error_msg = "No valid features after encoding"
logger.error(error_msg)
raise ValueError(error_msg)

# Apply modality interaction mechanisms
if self.use_gating:

# g: M → M̂
# Adaptive feature transformation with learned gates
logger.debug("Applying modality gating")
encoded_features = self.gating(encoded_features)
print(f"After gating, encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

# Cross-modal attention mechanism
if self.use_cross_attention:
if len(encoded_features) > 1:

logger.debug("Applying cross-modal attention")
# Multi-modal cross attention
encoded_features = self.cross_attention(encoded_features, mask)
else:

logger.debug("Single modality: skipping cross-attention")
# With a single modality there is nothing to attend across, so the
# features are passed through unchanged (identity mapping)
for key in encoded_features:
encoded_features[key] = encoded_features[key]  # Identity mapping

print(f"After cross-modal attention - encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

# Feature fusion and output generation
logger.debug("Applying fusion module")
fused_features = self.fusion_module(encoded_features, mask)
print(f"Fused features shape: {fused_features.shape}")
logger.debug(f"Fused features shape: {fused_features.shape}")

# Ensure input to final_block matches hidden_dim
# Ensure z ∈ ℝ^{B×H}, H: hidden dimension
batch_size = fused_features.size(0)

# Repeat the features to match the expected hidden dimension
if fused_features.size(1) != self.hidden_dim:
logger.debug("Adjusting fused features dimension")
fused_features = fused_features.repeat(1, self.hidden_dim // fused_features.size(1))

# Precision projection if dimension mismatch persists
# π_H: ℝ^k → ℝ^H
if fused_features.size(1) != self.hidden_dim:
logger.debug("Creating precision projection")
projection = nn.Linear(fused_features.size(1), self.hidden_dim).to(fused_features.device)
fused_features = projection(fused_features)
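# Note (added commentary, not in the original commit): this nn.Linear is built
# inside forward, so its weights are freshly initialised on every call and are
# never registered as trainable parameters of the encoder module.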

# Final output generation
# ψ: ℝ^H → ℝ^M, M: output features
output = self.final_block(fused_features)

logger.debug(f"Final output shape: {output.shape}")

return output


class DynamicResidualEncoder(DynamicFusionEncoder):
""" Dynamic fusion implementation with residual connectivity """

def __init__(self, *args, **kwargs):

logger.info("Initialising DynamicResidualEncoder")
super().__init__(*args, **kwargs)

# Enhanced projection with residual pathways
# With residual transformation
# φ_m: ℝ^H → ℝ^H
logger.debug("Creating residual feature projections")
self.feature_projections = nn.ModuleDict({
modality: nn.Sequential(
nn.LayerNorm(self.hidden_dim),
@@ -322,6 +363,9 @@ def forward(
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:

logger.info("Starting DynamicResidualEncoder forward pass")
logger.debug(f"Input modalities: {list(inputs.keys())}")

""" Forward implementation with residual pathways """

# Encoded features dictionary
@@ -333,35 +377,48 @@ def forward(
if modality not in self.modality_encoders or x is None:
continue

logger.debug(f"Processing {modality} with shape {x.shape}")
encoded = self.modality_encoders[modality](x)
logger.debug(f"Encoded shape: {encoded.shape}")

# Residual connection
# x_m ⊕ R_m(x_m)
projected = encoded + self.feature_projections[modality](encoded)
logger.debug(f"Projected shape with residual: {projected.shape}")
encoded_features[modality] = projected

if not encoded_features:
raise ValueError("No valid features after encoding")
error_msg = "No valid features after encoding"
logger.error(error_msg)
raise ValueError(error_msg)

# Gating with residual pathways
# g_m: x_m ⊕ g(x_m)
if self.use_gating:
logger.debug("Applying gating with residual connections")
gated_features = self.gating(encoded_features)
for modality in encoded_features:
gated_features[modality] = gated_features[modality] + encoded_features[modality]
encoded_features = gated_features

logger.debug(f"After gating shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

# Attention with residual pathways
# A_m: x_m ⊕ A(x_m)
if self.use_cross_attention and len(encoded_features) > 1:
logger.debug("Applying cross-attention with residual connections")
attended_features = self.cross_attention(encoded_features, mask)
for modality in encoded_features:
attended_features[modality] = attended_features[modality] + encoded_features[modality]
encoded_features = attended_features

logger.debug(f"After attention shapes: {[encoded_features[mod].shape for mod in encoded_features]}")

# Final fusion and output generation
logger.debug("Applying fusion module")
fused_features = self.fusion_module(encoded_features, mask)
fused_features = fused_features.repeat(1, self.sequence_length)
logger.debug(f"Fused features shape: {fused_features.shape}")

output = self.final_block(fused_features)

logger.debug(f"Final output shape: {output.shape}")

return output
