# dynamic_encoder.py

"""Dynamic fusion encoder implementation for multimodal learning."""


from typing import Dict, Optional, List, Union

import torch
from torch import nn


class PVEncoder(nn.Module):
    """Simplified PV encoder - maintains sequence dimension.

    Projects the per-site PV readings at each timestep into a feature
    vector, keeping the ``[batch, sequence, feature]`` layout so downstream
    fusion modules can attend over time.
    """

    def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        """Initialise the encoder.

        Args:
            sequence_length: Number of timesteps in each input sequence.
            num_sites: Number of PV sites (input feature width per timestep).
            out_features: Output feature width per timestep.
        """
        super().__init__()
        self.sequence_length = sequence_length
        self.num_sites = num_sites
        self.out_features = out_features

        # Every layer here acts only on the trailing (site/feature)
        # dimension, so the stack can be applied to the whole sequence in a
        # single call instead of one timestep at a time.
        self.encoder = nn.Sequential(
            nn.Linear(num_sites, out_features),
            nn.LayerNorm(out_features),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a PV sequence.

        Args:
            x: Tensor of shape ``[batch_size, sequence_length, num_sites]``.

        Returns:
            Tensor of shape ``[batch_size, sequence_length, out_features]``.
        """
        # nn.Linear and nn.LayerNorm broadcast over leading dimensions, so a
        # single call is equivalent to the previous per-timestep Python loop
        # (stacking encoder(x[:, t]) for each t) but fully vectorized.
        return self.encoder(x)
# NOTE(review): project-local imports are kept immediately above the classes
# that use them so this reconstructed section stands on its own.
from typing import Dict, Optional

from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating
from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention
from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2


class DynamicFusionEncoder(AbstractNWPSatelliteEncoder):
    """Encoder that implements dynamic fusion of satellite/NWP data streams."""

    def __init__(
        self,
        sequence_length: int,
        image_size_pixels: int,
        modality_channels: Dict[str, int],
        out_features: int,
        modality_encoders: Dict[str, dict],
        cross_attention: Dict,
        modality_gating: Dict,
        dynamic_fusion: Dict,
        hidden_dim: int = 256,
        fc_features: int = 128,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_gating: bool = True,
        use_cross_attention: bool = True,
    ):
        """Dynamic fusion encoder for multimodal satellite/NWP data.

        Args:
            sequence_length: Number of timesteps per input sequence.
            image_size_pixels: Default spatial size for image modalities.
            modality_channels: Mapping of modality name -> input channels.
            out_features: Width of the final encoder output.
            modality_encoders: Per-modality encoder configuration dicts.
            cross_attention: Config for the optional ``CrossModalAttention``.
            modality_gating: Config for the optional ``ModalityGating``.
            dynamic_fusion: Config for the ``DynamicFusionModule``.
            hidden_dim: Internal feature width used by gating/attention/fusion.
            fc_features: Hidden width of the final projection head.
            num_heads: Currently unused here; head counts come from the
                attention/fusion config dicts.
            dropout: Dropout probability for the feature projections.
            use_gating: Enable modality gating.
            use_cross_attention: Enable cross-modal attention.
        """
        super().__init__(
            sequence_length=sequence_length,
            image_size_pixels=image_size_pixels,
            in_channels=sum(modality_channels.values()),
            out_features=out_features,
        )

        self.modalities = list(modality_channels.keys())
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length
        # Stored so subclasses can reuse it without digging through kwargs.
        self.dropout = dropout

        # Modality-specific encoders.  Image-like modalities ('sat'/'nwp')
        # use a 3D conv encoder whose flat output is reshaped back into a
        # sequence; 'pv' uses the per-timestep PVEncoder.  Any other modality
        # name is silently skipped -- NOTE(review): consider raising instead.
        self.modality_encoders = nn.ModuleDict()
        for modality, config in modality_encoders.items():
            config = config.copy()
            if 'nwp' in modality or 'sat' in modality:
                encoder = DefaultPVNet2(
                    sequence_length=sequence_length,
                    image_size_pixels=config.get('image_size_pixels', image_size_pixels),
                    in_channels=modality_channels[modality],
                    out_features=config.get('out_features', hidden_dim),
                    number_of_conv3d_layers=config.get('number_of_conv3d_layers', 4),
                    conv3d_channels=config.get('conv3d_channels', 32),
                    batch_norm=config.get('batch_norm', True),
                    fc_dropout=config.get('fc_dropout', 0.2),
                )
                # Reshape the flat [batch, hidden_dim] output into
                # [batch, sequence_length, hidden_dim // sequence_length].
                # Assumes the encoder's out_features equals hidden_dim --
                # TODO(review): confirm for configs that override it.
                self.modality_encoders[modality] = nn.Sequential(
                    encoder,
                    nn.Unflatten(1, (sequence_length, hidden_dim // sequence_length)),
                )
            elif modality == 'pv':
                # Fix: honour the per-modality 'out_features' override
                # (previously hidden_dim was always used, contradicting the
                # test config's comment that pv out_features must match the
                # per-timestep width of the image branches).
                self.modality_encoders[modality] = PVEncoder(
                    sequence_length=sequence_length,
                    num_sites=config['num_sites'],
                    out_features=config.get('out_features', hidden_dim),
                )

        # Per-modality feature projections.
        # NOTE(review): these are built but never applied in forward() --
        # confirm whether they should wrap the encoded features.
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            for modality in modality_channels.keys()
        })

        # Optional modality gating.
        self.use_gating = use_gating
        if use_gating:
            gating_config = modality_gating.copy()
            gating_config['feature_dims'] = {
                mod: hidden_dim for mod in modality_channels.keys()
            }
            self.gating = ModalityGating(**gating_config)

        # Optional cross-modal attention.
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            attention_config = cross_attention.copy()
            attention_config['embed_dim'] = hidden_dim
            self.cross_attention = CrossModalAttention(**attention_config)

        # Dynamic fusion module combining all modality streams.
        fusion_config = dynamic_fusion.copy()
        fusion_config['feature_dims'] = {
            mod: hidden_dim for mod in modality_channels.keys()
        }
        fusion_config['hidden_dim'] = hidden_dim
        self.fusion_module = DynamicFusionModule(**fusion_config)

        # Final output projection: flatten the fused sequence and map to
        # out_features.
        self.final_block = nn.Sequential(
            nn.Linear(hidden_dim * sequence_length, fc_features),
            nn.ELU(),
            nn.Linear(fc_features, out_features),
            nn.ELU(),
        )

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass of the dynamic fusion encoder.

        Args:
            inputs: Mapping of modality name -> input tensor.  Entries with
                no configured encoder are ignored.
            mask: Optional attention mask forwarded to the attention and
                fusion modules.

        Returns:
            Tensor of shape ``[batch_size, out_features]``.

        Raises:
            ValueError: If none of the inputs has a configured encoder.
        """
        # Encode each modality independently.
        # Expected output per modality: [batch, sequence, feature] --
        # TODO(review): feature width differs between branches; confirm the
        # fusion module's expectations.
        encoded_features = {}
        for modality, x in inputs.items():
            if modality not in self.modality_encoders:
                continue
            encoded_features[modality] = self.modality_encoders[modality](x)

        if not encoded_features:
            raise ValueError("No valid features found in inputs")

        # Optional gating of each modality stream.
        if self.use_gating:
            encoded_features = self.gating(encoded_features)

        # Cross-modal attention only makes sense with >1 modality present.
        if self.use_cross_attention and len(encoded_features) > 1:
            encoded_features = self.cross_attention(encoded_features, mask)

        # Fuse modalities into a single sequence of features.
        fused_features = self.fusion_module(encoded_features, mask)

        # Flatten the sequence dimension and project to the output width.
        batch_size = fused_features.size(0)
        return self.final_block(fused_features.reshape(batch_size, -1))


class DynamicResidualEncoder(DynamicFusionEncoder):
    """Dynamic fusion encoder whose feature projections use a deeper,
    residual-style two-layer MLP."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Fix: use attributes stored by the parent rather than kwargs, so
        # this also works when arguments are passed positionally (the kwargs
        # lookup used to raise KeyError in that case).
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(self.dropout),
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim),
            )
            for modality in self.modalities
        })
import pytest
import torch

from pvnet.models.multimodal.encoders.dynamic_encoder import DynamicFusionEncoder


@pytest.fixture
def minimal_config():
    """Minimal configuration for testing basic functionality."""
    sequence_length = 12
    hidden_dim = 60  # chosen so it divides evenly by sequence_length (60/12 = 5)

    # Per-timestep feature width; must match between modalities.
    feature_dim = hidden_dim // sequence_length  # 5

    return {
        'sequence_length': sequence_length,
        'image_size_pixels': 24,
        'modality_channels': {
            'sat': 2,
            'pv': 10,
        },
        'out_features': 32,
        'hidden_dim': hidden_dim,
        'fc_features': 32,
        'modality_encoders': {
            'sat': {
                'image_size_pixels': 24,
                'out_features': feature_dim * sequence_length,  # 60
                'number_of_conv3d_layers': 2,
                'conv3d_channels': 16,
                'batch_norm': True,
                'fc_dropout': 0.1,
            },
            'pv': {
                'num_sites': 10,
                'out_features': feature_dim,  # 5 - matches the sat branch width
            },
        },
        'cross_attention': {
            'embed_dim': hidden_dim,
            'num_heads': 4,
            'dropout': 0.1,
            'num_modalities': 2,
        },
        'modality_gating': {
            'feature_dims': {
                'sat': hidden_dim,
                'pv': hidden_dim,
            },
            'hidden_dim': hidden_dim,
            'dropout': 0.1,
        },
        'dynamic_fusion': {
            'feature_dims': {
                'sat': hidden_dim,
                'pv': hidden_dim,
            },
            'hidden_dim': hidden_dim,
            'num_heads': 4,
            'dropout': 0.1,
            'fusion_method': 'weighted_sum',
            'use_residual': True,
        },
    }


@pytest.fixture
def minimal_inputs(minimal_config):
    """Generate minimal test inputs."""
    batch_size = 2
    sequence_length = minimal_config['sequence_length']

    return {
        'sat': torch.randn(batch_size, 2, sequence_length, 24, 24),
        'pv': torch.randn(batch_size, sequence_length, 10),
    }


# Fix: the original signature was ``test_batch_sizes(self, minimal_config,
# minimal_inputs, batch_size)`` on a module-level function, so pytest failed
# to resolve ``self`` and ``batch_size`` as fixtures and the test errored
# before running.  Drop ``self`` and supply batch_size via parametrize.
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_batch_sizes(minimal_config, minimal_inputs, batch_size):
    """Test that the encoder handles different batch sizes."""
    encoder = DynamicFusionEncoder(
        sequence_length=minimal_config['sequence_length'],
        image_size_pixels=minimal_config['image_size_pixels'],
        modality_channels=minimal_config['modality_channels'],
        out_features=minimal_config['out_features'],
        modality_encoders=minimal_config['modality_encoders'],
        cross_attention=minimal_config['cross_attention'],
        modality_gating=minimal_config['modality_gating'],
        dynamic_fusion=minimal_config['dynamic_fusion'],
        hidden_dim=minimal_config['hidden_dim'],
        fc_features=minimal_config['fc_features'],
    )

    # Slice down or tile up the fixture inputs to the requested batch size.
    # Tiling assumes batch_size is a multiple of the fixture batch (2), which
    # holds for the parametrized values above.
    adjusted_inputs = {}
    for key, value in minimal_inputs.items():
        if batch_size < value.size(0):
            adjusted_inputs[key] = value[:batch_size]
        else:
            repeat_factor = batch_size // value.size(0)
            adjusted_inputs[key] = value.repeat(repeat_factor, *[1] * (value.dim() - 1))

    with torch.no_grad():
        output = encoder(adjusted_inputs)

    assert output.shape == (batch_size, minimal_config['out_features'])
    assert not torch.isnan(output).any()