Skip to content

Commit

Permalink
Fusion blocks update
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-e-h-p committed Jan 18, 2025
1 parent 820d773 commit b6fabde
Showing 1 changed file with 57 additions and 44 deletions.
101 changes: 57 additions & 44 deletions pvnet/models/multimodal/fusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Aformentioned fusion blocks apply dynamic attention, weighted combinations and / or gating mechanisms for feature learning
Summararily, this enables dynamic feature learning through attention based weighting and modality specific gating
Summararily - enables dynamic feature learning through attention based weighting and modality specific gating
"""


Expand All @@ -24,6 +24,8 @@ class AbstractFusionBlock(nn.Module, ABC):
""" Abstract fusion base class definition """

# Forward pass
# Function mapping
# F: X → Y in fusion space
@abstractmethod
def forward(
self,
Expand All @@ -36,8 +38,12 @@ class DynamicFusionModule(AbstractFusionBlock):

""" Implementation of dynamic multimodal fusion through cross attention and weighted combination """

# Input dimension specified and common embedding dimension
# Quantity of attention heads also specified
# Define feature dimensions
# d_i ∈ ℝ^n
# Shared latent space
# ℝ^h
# Attention mechanisms
# A_i: ℝ^d → ℝ^{d/h}
def __init__(
self,
feature_dims: Dict[str, int],
Expand All @@ -61,6 +67,7 @@ def __init__(
raise ValueError(f"Invalid fusion method: {fusion_method}")

# Projections - modality specific
# φ_m: ℝ^{d_m} → ℝ^h for m ∈ M
self.projections = nn.ModuleDict({
name: nn.Sequential(
nn.Linear(dim, hidden_dim),
Expand All @@ -73,13 +80,15 @@ def __init__(
})

# Attention - cross modal
# ℝ^{d_m} → ℝ^h
self.cross_attention = MultiheadAttention(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout
)

# Weight computation network definition
# Weight computation network definition - dynamic
# W: ℝ^h → [0,1]
self.weight_network = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
Expand All @@ -89,6 +98,7 @@ def __init__(
)

# Optional concat projection
# P: ℝ^{h|M|} → ℝ^h
if fusion_method == "concat":
self.output_projection = nn.Sequential(
nn.Linear(hidden_dim * len(feature_dims), hidden_dim),
Expand All @@ -99,48 +109,32 @@ def __init__(

if use_residual:
self.layer_norm = nn.LayerNorm(hidden_dim)

# def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
# """ Validates input feature dimensions and sequence lengths """

# if not features:
# raise ValueError("Empty features dict")

# seq_length = None
# for name, feat in features.items():
# if feat is None:
# raise ValueError(f"None tensor for modality: {name}")

# if seq_length is None:
# seq_length = feat.size(1)
# elif feat.size(1) != seq_length:
# raise ValueError("All modalities must have same sequence length")

def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
""" Validates input feature dimensions and sequence lengths """

# Handle case where features might be a single tensor or empty
# Validate feature space dimensionality d_m
# Validate sequence length L
if not isinstance(features, dict) or not features:
if isinstance(features, torch.Tensor):
return # Skip validation for single tensor
raise ValueError("Empty features dict")

# Collect feature lengths for features with 2D+ tensors
# Validate temporal dimensions L_m across modalities
multi_dim_features = {}
for name, feat in features.items():
if feat is None:
raise ValueError(f"None tensor for modality: {name}")

# Only consider features with more than 1 dimension
if feat.ndim > 1:
multi_dim_features[name] = feat.size(1)

# If more than one unique length, raise an error
# Verification step
# L_i = L_j ∀i,j ∈ M
feature_lengths = set(multi_dim_features.values())
if len(feature_lengths) > 1:
raise ValueError(f"All modalities must have same sequence length. Current lengths: {multi_dim_features}")


def compute_modality_weights(
self,
features: torch.Tensor,
Expand All @@ -159,6 +153,7 @@ def compute_modality_weights(
weights = weights.masked_fill(~modality_mask.unsqueeze(-1), 0.0)

# Normalisation of weights
# α_m = w_m / Σ_j w_j
weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)
return weights

Expand All @@ -175,7 +170,8 @@ def forward(
batch_size = next(iter(features.values())).size(0)
seq_len = next(iter(features.values())).size(1)

# Project each modality
# Apply modality-specific embeddings
# φ_m(x_m)
projected_features = []
for name, feat in features.items():
if self.feature_dims[name] > 0:
Expand All @@ -187,39 +183,47 @@ def forward(
if not projected_features:
raise ValueError("No valid features after projection")

# Stack features
# Tensor product of embedded features
# ⊗_{m∈M} φ_m(x_m)
feature_stack = torch.stack(projected_features, dim=1)

# Apply cross attention
# A(⊗_{m∈M} φ_m(x_m))
attended_features = []
for i in range(feature_stack.size(1)):
query = feature_stack[:, i]
key_value = feature_stack[:, [j for j in range(feature_stack.size(1)) if j != i]]
if key_value.size(1) > 0:
if feature_stack.size(1) > 1:

# Case |M| > 1
# Apply cross-modal attention A_c
key_value = feature_stack[:, [j for j in range(feature_stack.size(1)) if j != i]]
attended = self.cross_attention(query, key_value.reshape(-1, seq_len, self.hidden_dim),
key_value.reshape(-1, seq_len, self.hidden_dim))
attended_features.append(attended)
else:
attended_features.append(query)

# Average across modalities

# Case |M| = 1
# Apply self-attention A_s
attended = self.cross_attention(query, query, query)
attended_features.append(attended)

# Compute mean representation
# μ = 1/|M| Σ_{m∈M} A_m
attended_features = torch.stack(attended_features, dim=1)
attended_avg = attended_features.mean(dim=1)

# Mask attended features to match
# Apply attention mask
# M ∈ {0,1}^{B×L}
if modality_mask is not None:
# Create binary mask matching sequence length
seq_mask = torch.zeros((batch_size, seq_len), device=attended_avg.device).bool()
seq_mask[:, :modality_mask.size(1)] = modality_mask

# Compute weights on masked features
seq_mask[:, :modality_mask.size(1)] = modality_mask
weights = self.compute_modality_weights(attended_avg, seq_mask)
weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1)
else:
weights = self.compute_modality_weights(attended_avg)
weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1)

# Application of weighted features
# Apply dynamic modality weights
# w_m ∈ [0,1]
weighted_features = attended_features * weights

if self.fusion_method == "weighted_sum":
Expand All @@ -229,11 +233,14 @@ def forward(
fused = self.output_projection(concat)

# Application of residual
# r(x) = LayerNorm(x + μ(x))
if self.use_residual:
residual = feature_stack.mean(dim=1)
fused = self.layer_norm(fused + residual)

# Collapse sequence dimension for output
# Temporal pooling
# τ: ℝ^{B×L×h} → ℝ^{B×h}
fused = fused.mean(dim=1)

return fused
Expand All @@ -243,6 +250,10 @@ class ModalityGating(AbstractFusionBlock):
""" Implementation of modality specific gating mechanism """

# Input and hidden dimension definition
# Input spaces
# X_m ∈ ℝ^{d_m}
# Hidden space
# H ∈ ℝ^h
def __init__(
self,
feature_dims: Dict[str, int],
Expand All @@ -257,7 +268,8 @@ def __init__(
self.feature_dims = feature_dims
self.hidden_dim = hidden_dim

# Define gate networks for each modality
# Define gate networks for each modality - functions
# g_m: ℝ^{d_m} → [0,1]
self.gate_networks = nn.ModuleDict({
name: nn.Sequential(
nn.Linear(dim, hidden_dim),
Expand All @@ -279,7 +291,6 @@ def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
if feat is None:
raise ValueError(f"None tensor for modality: {name}")


def forward(
self,
features: Dict[str, torch.Tensor]
Expand All @@ -294,12 +305,14 @@ def forward(
if feat is not None and name in self.gate_networks:
batch_size, seq_len, feat_dim = feat.shape

# Gate computation sequence
# Compute gating activation
# α_m = σ(g_m(x_m))
flat_feat = feat.reshape(-1, feat_dim)
gate = self.gate_networks[name](flat_feat)
gate = gate.reshape(batch_size, seq_len, 1)

# Application of gating
# Apply multiplicative gating
# y_m = x_m ⊙ α_m
gated_features[name] = feat * gate

return gated_features

0 comments on commit b6fabde

Please sign in to comment.