Skip to content

Commit b6fabde

Browse files
committed
Fusion blocks update
1 parent 820d773 commit b6fabde

File tree

1 file changed

+57
-44
lines changed

1 file changed

+57
-44
lines changed

pvnet/models/multimodal/fusion_blocks.py

Lines changed: 57 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
88
Aforementioned fusion blocks apply dynamic attention, weighted combinations and / or gating mechanisms for feature learning
99
10-
In summary, this enables dynamic feature learning through attention-based weighting and modality-specific gating
10+
In summary - enables dynamic feature learning through attention-based weighting and modality-specific gating
1111
"""
1212

1313

@@ -24,6 +24,8 @@ class AbstractFusionBlock(nn.Module, ABC):
2424
""" Abstract fusion base class definition """
2525

2626
# Forward pass
27+
# Function mapping
28+
# F: X → Y in fusion space
2729
@abstractmethod
2830
def forward(
2931
self,
@@ -36,8 +38,12 @@ class DynamicFusionModule(AbstractFusionBlock):
3638

3739
""" Implementation of dynamic multimodal fusion through cross attention and weighted combination """
3840

39-
# Input dimension specified and common embedding dimension
40-
# Quantity of attention heads also specified
41+
# Define feature dimensions
42+
# d_i ∈ ℝ^n
43+
# Shared latent space
44+
# ℝ^h
45+
# Attention mechanisms
46+
# A_i: ℝ^d → ℝ^{d/h}
4147
def __init__(
4248
self,
4349
feature_dims: Dict[str, int],
@@ -61,6 +67,7 @@ def __init__(
6167
raise ValueError(f"Invalid fusion method: {fusion_method}")
6268

6369
# Projections - modality specific
70+
# φ_m: ℝ^{d_m} → ℝ^h for m ∈ M
6471
self.projections = nn.ModuleDict({
6572
name: nn.Sequential(
6673
nn.Linear(dim, hidden_dim),
@@ -73,13 +80,15 @@ def __init__(
7380
})
7481

7582
# Attention - cross modal
83+
# A_c: ℝ^h × ℝ^h × ℝ^h → ℝ^h (query, key and value all live in the shared latent space, embed_dim = hidden_dim)
7684
self.cross_attention = MultiheadAttention(
7785
embed_dim=hidden_dim,
7886
num_heads=num_heads,
7987
dropout=dropout
8088
)
8189

82-
# Weight computation network definition
90+
# Weight computation network definition - dynamic
91+
# W: ℝ^h → [0,1]
8392
self.weight_network = nn.Sequential(
8493
nn.Linear(hidden_dim, hidden_dim // 2),
8594
nn.ReLU(),
@@ -89,6 +98,7 @@ def __init__(
8998
)
9099

91100
# Optional concat projection
101+
# P: ℝ^{h|M|} → ℝ^h
92102
if fusion_method == "concat":
93103
self.output_projection = nn.Sequential(
94104
nn.Linear(hidden_dim * len(feature_dims), hidden_dim),
@@ -99,48 +109,32 @@ def __init__(
99109

100110
if use_residual:
101111
self.layer_norm = nn.LayerNorm(hidden_dim)
102-
103-
# def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
104-
# """ Validates input feature dimensions and sequence lengths """
105-
106-
# if not features:
107-
# raise ValueError("Empty features dict")
108-
109-
# seq_length = None
110-
# for name, feat in features.items():
111-
# if feat is None:
112-
# raise ValueError(f"None tensor for modality: {name}")
113-
114-
# if seq_length is None:
115-
# seq_length = feat.size(1)
116-
# elif feat.size(1) != seq_length:
117-
# raise ValueError("All modalities must have same sequence length")
118112

119113
def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
    """Validate the per-modality feature tensors before fusion.

    Checks performed, in order:
      * a bare tensor is accepted as-is (there is no second modality to
        cross-check it against);
      * any other non-dict input is rejected with a message naming the
        offending type;
      * an empty dict is rejected;
      * no modality may map to ``None``;
      * every tensor with more than one dimension must agree on the size of
        dimension 1, which this module treats as the sequence length.
        Tensors with a single dimension are ignored by this check.

    Args:
        features: Mapping of modality name to feature tensor.

    Raises:
        ValueError: If ``features`` is neither a tensor nor a dict, is an
            empty dict, contains a ``None`` entry, or holds tensors whose
            dimension-1 sizes disagree.
    """

    # A single tensor carries no cross-modality constraints to verify
    if isinstance(features, torch.Tensor):
        return

    # Report a wrong container type explicitly rather than as an "empty dict";
    # ValueError is kept so existing callers catching it keep working
    if not isinstance(features, dict):
        raise ValueError(
            f"Features must be a dict of tensors, got {type(features).__name__}"
        )

    if not features:
        raise ValueError("Empty features dict")

    # Reject missing modalities before inspecting any tensor attributes
    for name, feat in features.items():
        if feat is None:
            raise ValueError(f"None tensor for modality: {name}")

    # Sequence length L_m for every modality that actually has a sequence axis
    seq_lengths = {
        name: feat.size(1) for name, feat in features.items() if feat.ndim > 1
    }

    # Verification step: L_i = L_j for all i, j in M
    if len(set(seq_lengths.values())) > 1:
        raise ValueError(f"All modalities must have same sequence length. Current lengths: {seq_lengths}")
142137

143-
144138
def compute_modality_weights(
145139
self,
146140
features: torch.Tensor,
@@ -159,6 +153,7 @@ def compute_modality_weights(
159153
weights = weights.masked_fill(~modality_mask.unsqueeze(-1), 0.0)
160154

161155
# Normalisation of weights
156+
# α_m = w_m / Σ_j w_j
162157
weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)
163158
return weights
164159

@@ -175,7 +170,8 @@ def forward(
175170
batch_size = next(iter(features.values())).size(0)
176171
seq_len = next(iter(features.values())).size(1)
177172

178-
# Project each modality
173+
# Apply modality-specific embeddings
174+
# φ_m(x_m)
179175
projected_features = []
180176
for name, feat in features.items():
181177
if self.feature_dims[name] > 0:
@@ -187,39 +183,47 @@ def forward(
187183
if not projected_features:
188184
raise ValueError("No valid features after projection")
189185

190-
# Stack features
186+
# Stack embedded features along a new modality axis (dim=1)
187+
# [φ_m(x_m)]_{m∈M}
191188
feature_stack = torch.stack(projected_features, dim=1)
192-
189+
193190
# Apply cross attention
191+
# A(⊗_{m∈M} φ_m(x_m))
194192
attended_features = []
195193
for i in range(feature_stack.size(1)):
196194
query = feature_stack[:, i]
197-
key_value = feature_stack[:, [j for j in range(feature_stack.size(1)) if j != i]]
198-
if key_value.size(1) > 0:
195+
if feature_stack.size(1) > 1:
196+
197+
# Case |M| > 1
198+
# Apply cross-modal attention A_c
199+
key_value = feature_stack[:, [j for j in range(feature_stack.size(1)) if j != i]]
199200
attended = self.cross_attention(query, key_value.reshape(-1, seq_len, self.hidden_dim),
200201
key_value.reshape(-1, seq_len, self.hidden_dim))
201-
attended_features.append(attended)
202202
else:
203-
attended_features.append(query)
204-
205-
# Average across modalities
203+
204+
# Case |M| = 1
205+
# Apply self-attention A_s
206+
attended = self.cross_attention(query, query, query)
207+
attended_features.append(attended)
208+
209+
# Compute mean representation
210+
# μ = 1/|M| Σ_{m∈M} A_m
206211
attended_features = torch.stack(attended_features, dim=1)
207212
attended_avg = attended_features.mean(dim=1)
208213

209-
# Mask attended features to match
214+
# Apply attention mask
215+
# M ∈ {0,1}^{B×L}
210216
if modality_mask is not None:
211-
# Create binary mask matching sequence length
212217
seq_mask = torch.zeros((batch_size, seq_len), device=attended_avg.device).bool()
213-
seq_mask[:, :modality_mask.size(1)] = modality_mask
214-
215-
# Compute weights on masked features
218+
seq_mask[:, :modality_mask.size(1)] = modality_mask
216219
weights = self.compute_modality_weights(attended_avg, seq_mask)
217220
weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1)
218221
else:
219222
weights = self.compute_modality_weights(attended_avg)
220223
weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1)
221224

222-
# Application of weighted features
225+
# Apply dynamic modality weights
226+
# w_m ∈ [0,1]
223227
weighted_features = attended_features * weights
224228

225229
if self.fusion_method == "weighted_sum":
@@ -229,11 +233,14 @@ def forward(
229233
fused = self.output_projection(concat)
230234

231235
# Application of residual
236+
# r(x) = LayerNorm(x + μ(x))
232237
if self.use_residual:
233238
residual = feature_stack.mean(dim=1)
234239
fused = self.layer_norm(fused + residual)
235240

236241
# Collapse sequence dimension for output
242+
# Temporal pooling
243+
# τ: ℝ^{B×L×h} → ℝ^{B×h}
237244
fused = fused.mean(dim=1)
238245

239246
return fused
@@ -243,6 +250,10 @@ class ModalityGating(AbstractFusionBlock):
243250
""" Implementation of modality specific gating mechanism """
244251

245252
# Input and hidden dimension definition
253+
# Input spaces
254+
# X_m ∈ ℝ^{d_m}
255+
# Hidden space
256+
# H ∈ ℝ^h
246257
def __init__(
247258
self,
248259
feature_dims: Dict[str, int],
@@ -257,7 +268,8 @@ def __init__(
257268
self.feature_dims = feature_dims
258269
self.hidden_dim = hidden_dim
259270

260-
# Define gate networks for each modality
271+
# Define gate networks for each modality - functions
272+
# g_m: ℝ^{d_m} → [0,1]
261273
self.gate_networks = nn.ModuleDict({
262274
name: nn.Sequential(
263275
nn.Linear(dim, hidden_dim),
@@ -279,7 +291,6 @@ def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
279291
if feat is None:
280292
raise ValueError(f"None tensor for modality: {name}")
281293

282-
283294
def forward(
284295
self,
285296
features: Dict[str, torch.Tensor]
@@ -294,12 +305,14 @@ def forward(
294305
if feat is not None and name in self.gate_networks:
295306
batch_size, seq_len, feat_dim = feat.shape
296307

297-
# Gate computation sequence
308+
# Compute gating activation
309+
# α_m = σ(g_m(x_m))
298310
flat_feat = feat.reshape(-1, feat_dim)
299311
gate = self.gate_networks[name](flat_feat)
300312
gate = gate.reshape(batch_size, seq_len, 1)
301313

302-
# Application of gating
314+
# Apply multiplicative gating
315+
# y_m = x_m ⊙ α_m
303316
gated_features[name] = feat * gate
304317

305318
return gated_features

0 commit comments

Comments
 (0)