From b6fabde9809f3a7965e328a9087c212a0b6d19a6 Mon Sep 17 00:00:00 2001 From: Felix Peretz Date: Sat, 18 Jan 2025 09:24:20 +0000 Subject: [PATCH] Fusion blocks update --- pvnet/models/multimodal/fusion_blocks.py | 101 +++++++++++++---------- 1 file changed, 57 insertions(+), 44 deletions(-) diff --git a/pvnet/models/multimodal/fusion_blocks.py b/pvnet/models/multimodal/fusion_blocks.py index 71146302..51c68752 100644 --- a/pvnet/models/multimodal/fusion_blocks.py +++ b/pvnet/models/multimodal/fusion_blocks.py @@ -7,7 +7,7 @@ Aformentioned fusion blocks apply dynamic attention, weighted combinations and / or gating mechanisms for feature learning -Summararily, this enables dynamic feature learning through attention based weighting and modality specific gating +Summararily - enables dynamic feature learning through attention based weighting and modality specific gating """ @@ -24,6 +24,8 @@ class AbstractFusionBlock(nn.Module, ABC): """ Abstract fusion base class definition """ # Forward pass + # Function mapping + # F: X → Y in fusion space @abstractmethod def forward( self, @@ -36,8 +38,12 @@ class DynamicFusionModule(AbstractFusionBlock): """ Implementation of dynamic multimodal fusion through cross attention and weighted combination """ - # Input dimension specified and common embedding dimension - # Quantity of attention heads also specified + # Define feature dimensions + # d_i ∈ ℝ^n + # Shared latent space + # ℝ^h + # Attention mechanisms + # A_i: ℝ^d → ℝ^{d/h} def __init__( self, feature_dims: Dict[str, int], @@ -61,6 +67,7 @@ def __init__( raise ValueError(f"Invalid fusion method: {fusion_method}") # Projections - modality specific + # φ_m: ℝ^{d_m} → ℝ^h for m ∈ M self.projections = nn.ModuleDict({ name: nn.Sequential( nn.Linear(dim, hidden_dim), @@ -73,13 +80,15 @@ def __init__( }) # Attention - cross modal + # ℝ^{d_m} → ℝ^h self.cross_attention = MultiheadAttention( embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout ) - # Weight computation network definition + # Weight computation network definition - dynamic + # W: ℝ^h → [0,1] self.weight_network = nn.Sequential( nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(), @@ -89,6 +98,7 @@ def __init__( ) # Optional concat projection + # P: ℝ^{h|M|} → ℝ^h if fusion_method == "concat": self.output_projection = nn.Sequential( nn.Linear(hidden_dim * len(feature_dims), hidden_dim), @@ -99,48 +109,32 @@ def __init__( if use_residual: self.layer_norm = nn.LayerNorm(hidden_dim) - - # def _validate_features(self, features: Dict[str, torch.Tensor]) -> None: - # """ Validates input feature dimensions and sequence lengths """ - - # if not features: - # raise ValueError("Empty features dict") - - # seq_length = None - # for name, feat in features.items(): - # if feat is None: - # raise ValueError(f"None tensor for modality: {name}") - - # if seq_length is None: - # seq_length = feat.size(1) - # elif feat.size(1) != seq_length: - # raise ValueError("All modalities must have same sequence length") def _validate_features(self, features: Dict[str, torch.Tensor]) -> None: """ Validates input feature dimensions and sequence lengths """ - # Handle case where features might be a single tensor or empty + # Validate feature space dimensionality d_m + # Validate sequence length L if not isinstance(features, dict) or not features: if isinstance(features, torch.Tensor): return # Skip validation for single tensor raise ValueError("Empty features dict") - # Collect feature lengths for features with 2D+ tensors + # Validate temporal dimensions L_m across modalities multi_dim_features = {} for name, feat in features.items(): if feat is None: raise ValueError(f"None tensor for modality: {name}") - # Only consider features with more than 1 dimension if feat.ndim > 1: multi_dim_features[name] = feat.size(1) - # If more than one unique length, raise an error + # Verification step + # L_i = L_j ∀i,j ∈ M feature_lengths = set(multi_dim_features.values()) if len(feature_lengths) > 1: raise ValueError(f"All modalities must have same sequence length. Current lengths: {multi_dim_features}") - def compute_modality_weights( self, features: torch.Tensor, @@ -159,6 +153,7 @@ def compute_modality_weights( weights = weights.masked_fill(~modality_mask.unsqueeze(-1), 0.0) # Normalisation of weights + # α_m = w_m / Σ_j w_j weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9) return weights @@ -175,7 +170,8 @@ def forward( batch_size = next(iter(features.values())).size(0) seq_len = next(iter(features.values())).size(1) - # Project each modality + # Apply modality-specific embeddings + # φ_m(x_m) projected_features = [] for name, feat in features.items(): if self.feature_dims[name] > 0: @@ -187,39 +183,47 @@ def forward( if not projected_features: raise ValueError("No valid features after projection") - # Stack features + # Tensor product of embedded features + # ⊗_{m∈M} φ_m(x_m) feature_stack = torch.stack(projected_features, dim=1) - + # Apply cross attention + # A(⊗_{m∈M} φ_m(x_m)) attended_features = [] for i in range(feature_stack.size(1)): query = feature_stack[:, i] - key_value = feature_stack[:, [j for j in range(feature_stack.size(1)) if j != i]] - if key_value.size(1) > 0: + if feature_stack.size(1) > 1: + + # Case |M| > 1 + # Apply cross-modal attention A_c + key_value = feature_stack[:, [j for j in range(feature_stack.size(1)) if j != i]] attended = self.cross_attention(query, key_value.reshape(-1, seq_len, self.hidden_dim), key_value.reshape(-1, seq_len, self.hidden_dim)) - attended_features.append(attended) else: - attended_features.append(query) - - # Average across modalities + + # Case |M| = 1 + # Apply self-attention A_s + attended = self.cross_attention(query, query, query) + attended_features.append(attended) + + # Compute mean representation + # μ = 1/|M| Σ_{m∈M} A_m attended_features = torch.stack(attended_features, dim=1) attended_avg = attended_features.mean(dim=1) - # Mask attended features to match + # Apply attention mask + # M ∈ {0,1}^{B×L} if modality_mask is not None: - # Create binary mask matching sequence length seq_mask = torch.zeros((batch_size, seq_len), device=attended_avg.device).bool() - seq_mask[:, :modality_mask.size(1)] = modality_mask - - # Compute weights on masked features + seq_mask[:, :modality_mask.size(1)] = modality_mask weights = self.compute_modality_weights(attended_avg, seq_mask) weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1) else: weights = self.compute_modality_weights(attended_avg) weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1) - # Application of weighted features + # Apply dynamic modality weights + # w_m ∈ [0,1] weighted_features = attended_features * weights if self.fusion_method == "weighted_sum": @@ -229,11 +233,14 @@ def forward( fused = self.output_projection(concat) # Application of residual + # r(x) = LayerNorm(x + μ(x)) if self.use_residual: residual = feature_stack.mean(dim=1) fused = self.layer_norm(fused + residual) # Collapse sequence dimension for output + # Temporal pooling + # τ: ℝ^{B×L×h} → ℝ^{B×h} fused = fused.mean(dim=1) return fused @@ -243,6 +250,10 @@ class ModalityGating(AbstractFusionBlock): """ Implementation of modality specific gating mechanism """ # Input and hidden dimension definition + # Input spaces + # X_m ∈ ℝ^{d_m} + # Hidden space + # H ∈ ℝ^h def __init__( self, feature_dims: Dict[str, int], @@ -257,7 +268,8 @@ def __init__( self.feature_dims = feature_dims self.hidden_dim = hidden_dim - # Define gate networks for each modality + # Define gate networks for each modality - functions + # g_m: ℝ^{d_m} → [0,1] self.gate_networks = nn.ModuleDict({ name: nn.Sequential( nn.Linear(dim, hidden_dim), @@ -279,7 +291,6 @@ def _validate_features(self, features: Dict[str, torch.Tensor]) -> None: if feat is None: raise ValueError(f"None tensor for modality: {name}") - def forward( self, features: Dict[str, torch.Tensor] @@ -294,12 +305,14 @@ def forward( if feat is not None and name in self.gate_networks: batch_size, seq_len, feat_dim = feat.shape - # Gate computation sequence + # Compute gating activation + # α_m = σ(g_m(x_m)) flat_feat = feat.reshape(-1, feat_dim) gate = self.gate_networks[name](flat_feat) gate = gate.reshape(batch_size, seq_len, 1) - # Application of gating + # Apply multiplicative gating + # y_m = x_m ⊙ α_m gated_features[name] = feat * gate return gated_features