7
7
8
8
Aformentioned fusion blocks apply dynamic attention, weighted combinations and / or gating mechanisms for feature learning
9
9
10
- Summararily, this enables dynamic feature learning through attention based weighting and modality specific gating
10
+ Summararily - enables dynamic feature learning through attention based weighting and modality specific gating
11
11
"""
12
12
13
13
@@ -24,6 +24,8 @@ class AbstractFusionBlock(nn.Module, ABC):
24
24
""" Abstract fusion base class definition """
25
25
26
26
# Forward pass
27
+ # Function mapping
28
+ # F: X → Y in fusion space
27
29
@abstractmethod
28
30
def forward (
29
31
self ,
@@ -36,8 +38,12 @@ class DynamicFusionModule(AbstractFusionBlock):
36
38
37
39
""" Implementation of dynamic multimodal fusion through cross attention and weighted combination """
38
40
39
- # Input dimension specified and common embedding dimension
40
- # Quantity of attention heads also specified
41
+ # Define feature dimensions
42
+ # d_i ∈ ℝ^n
43
+ # Shared latent space
44
+ # ℝ^h
45
+ # Attention mechanisms
46
+ # A_i: ℝ^d → ℝ^{d/h}
41
47
def __init__ (
42
48
self ,
43
49
feature_dims : Dict [str , int ],
@@ -61,6 +67,7 @@ def __init__(
61
67
raise ValueError (f"Invalid fusion method: { fusion_method } " )
62
68
63
69
# Projections - modality specific
70
+ # φ_m: ℝ^{d_m} → ℝ^h for m ∈ M
64
71
self .projections = nn .ModuleDict ({
65
72
name : nn .Sequential (
66
73
nn .Linear (dim , hidden_dim ),
@@ -73,13 +80,15 @@ def __init__(
73
80
})
74
81
75
82
# Attention - cross modal
83
+ # ℝ^{d_m} → ℝ^h
76
84
self .cross_attention = MultiheadAttention (
77
85
embed_dim = hidden_dim ,
78
86
num_heads = num_heads ,
79
87
dropout = dropout
80
88
)
81
89
82
- # Weight computation network definition
90
+ # Weight computation network definition - dynamic
91
+ # W: ℝ^h → [0,1]
83
92
self .weight_network = nn .Sequential (
84
93
nn .Linear (hidden_dim , hidden_dim // 2 ),
85
94
nn .ReLU (),
@@ -89,6 +98,7 @@ def __init__(
89
98
)
90
99
91
100
# Optional concat projection
101
+ # P: ℝ^{h|M|} → ℝ^h
92
102
if fusion_method == "concat" :
93
103
self .output_projection = nn .Sequential (
94
104
nn .Linear (hidden_dim * len (feature_dims ), hidden_dim ),
@@ -99,48 +109,32 @@ def __init__(
99
109
100
110
if use_residual :
101
111
self .layer_norm = nn .LayerNorm (hidden_dim )
102
-
103
- # def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
104
- # """ Validates input feature dimensions and sequence lengths """
105
-
106
- # if not features:
107
- # raise ValueError("Empty features dict")
108
-
109
- # seq_length = None
110
- # for name, feat in features.items():
111
- # if feat is None:
112
- # raise ValueError(f"None tensor for modality: {name}")
113
-
114
- # if seq_length is None:
115
- # seq_length = feat.size(1)
116
- # elif feat.size(1) != seq_length:
117
- # raise ValueError("All modalities must have same sequence length")
118
112
119
113
def _validate_features (self , features : Dict [str , torch .Tensor ]) -> None :
120
114
""" Validates input feature dimensions and sequence lengths """
121
115
122
- # Handle case where features might be a single tensor or empty
116
+ # Validate feature space dimensionality d_m
117
+ # Validate sequence length L
123
118
if not isinstance (features , dict ) or not features :
124
119
if isinstance (features , torch .Tensor ):
125
120
return # Skip validation for single tensor
126
121
raise ValueError ("Empty features dict" )
127
122
128
- # Collect feature lengths for features with 2D+ tensors
123
+ # Validate temporal dimensions L_m across modalities
129
124
multi_dim_features = {}
130
125
for name , feat in features .items ():
131
126
if feat is None :
132
127
raise ValueError (f"None tensor for modality: { name } " )
133
128
134
- # Only consider features with more than 1 dimension
135
129
if feat .ndim > 1 :
136
130
multi_dim_features [name ] = feat .size (1 )
137
131
138
- # If more than one unique length, raise an error
132
+ # Verification step
133
+ # L_i = L_j ∀i,j ∈ M
139
134
feature_lengths = set (multi_dim_features .values ())
140
135
if len (feature_lengths ) > 1 :
141
136
raise ValueError (f"All modalities must have same sequence length. Current lengths: { multi_dim_features } " )
142
137
143
-
144
138
def compute_modality_weights (
145
139
self ,
146
140
features : torch .Tensor ,
@@ -159,6 +153,7 @@ def compute_modality_weights(
159
153
weights = weights .masked_fill (~ modality_mask .unsqueeze (- 1 ), 0.0 )
160
154
161
155
# Normalisation of weights
156
+ # α_m = w_m / Σ_j w_j
162
157
weights = weights / (weights .sum (dim = 1 , keepdim = True ) + 1e-9 )
163
158
return weights
164
159
@@ -175,7 +170,8 @@ def forward(
175
170
batch_size = next (iter (features .values ())).size (0 )
176
171
seq_len = next (iter (features .values ())).size (1 )
177
172
178
- # Project each modality
173
+ # Apply modality-specific embeddings
174
+ # φ_m(x_m)
179
175
projected_features = []
180
176
for name , feat in features .items ():
181
177
if self .feature_dims [name ] > 0 :
@@ -187,39 +183,47 @@ def forward(
187
183
if not projected_features :
188
184
raise ValueError ("No valid features after projection" )
189
185
190
- # Stack features
186
+ # Tensor product of embedded features
187
+ # ⊗_{m∈M} φ_m(x_m)
191
188
feature_stack = torch .stack (projected_features , dim = 1 )
192
-
189
+
193
190
# Apply cross attention
191
+ # A(⊗_{m∈M} φ_m(x_m))
194
192
attended_features = []
195
193
for i in range (feature_stack .size (1 )):
196
194
query = feature_stack [:, i ]
197
- key_value = feature_stack [:, [j for j in range (feature_stack .size (1 )) if j != i ]]
198
- if key_value .size (1 ) > 0 :
195
+ if feature_stack .size (1 ) > 1 :
196
+
197
+ # Case |M| > 1
198
+ # Apply cross-modal attention A_c
199
+ key_value = feature_stack [:, [j for j in range (feature_stack .size (1 )) if j != i ]]
199
200
attended = self .cross_attention (query , key_value .reshape (- 1 , seq_len , self .hidden_dim ),
200
201
key_value .reshape (- 1 , seq_len , self .hidden_dim ))
201
- attended_features .append (attended )
202
202
else :
203
- attended_features .append (query )
204
-
205
- # Average across modalities
203
+
204
+ # Case |M| = 1
205
+ # Apply self-attention A_s
206
+ attended = self .cross_attention (query , query , query )
207
+ attended_features .append (attended )
208
+
209
+ # Compute mean representation
210
+ # μ = 1/|M| Σ_{m∈M} A_m
206
211
attended_features = torch .stack (attended_features , dim = 1 )
207
212
attended_avg = attended_features .mean (dim = 1 )
208
213
209
- # Mask attended features to match
214
+ # Apply attention mask
215
+ # M ∈ {0,1}^{B×L}
210
216
if modality_mask is not None :
211
- # Create binary mask matching sequence length
212
217
seq_mask = torch .zeros ((batch_size , seq_len ), device = attended_avg .device ).bool ()
213
- seq_mask [:, :modality_mask .size (1 )] = modality_mask
214
-
215
- # Compute weights on masked features
218
+ seq_mask [:, :modality_mask .size (1 )] = modality_mask
216
219
weights = self .compute_modality_weights (attended_avg , seq_mask )
217
220
weights = weights .unsqueeze (1 ).expand (- 1 , attended_features .size (1 ), - 1 , 1 )
218
221
else :
219
222
weights = self .compute_modality_weights (attended_avg )
220
223
weights = weights .unsqueeze (1 ).expand (- 1 , attended_features .size (1 ), - 1 , 1 )
221
224
222
- # Application of weighted features
225
+ # Apply dynamic modality weights
226
+ # w_m ∈ [0,1]
223
227
weighted_features = attended_features * weights
224
228
225
229
if self .fusion_method == "weighted_sum" :
@@ -229,11 +233,14 @@ def forward(
229
233
fused = self .output_projection (concat )
230
234
231
235
# Application of residual
236
+ # r(x) = LayerNorm(x + μ(x))
232
237
if self .use_residual :
233
238
residual = feature_stack .mean (dim = 1 )
234
239
fused = self .layer_norm (fused + residual )
235
240
236
241
# Collapse sequence dimension for output
242
+ # Temporal pooling
243
+ # τ: ℝ^{B×L×h} → ℝ^{B×h}
237
244
fused = fused .mean (dim = 1 )
238
245
239
246
return fused
@@ -243,6 +250,10 @@ class ModalityGating(AbstractFusionBlock):
243
250
""" Implementation of modality specific gating mechanism """
244
251
245
252
# Input and hidden dimension definition
253
+ # Input spaces
254
+ # X_m ∈ ℝ^{d_m}
255
+ # Hidden space
256
+ # H ∈ ℝ^h
246
257
def __init__ (
247
258
self ,
248
259
feature_dims : Dict [str , int ],
@@ -257,7 +268,8 @@ def __init__(
257
268
self .feature_dims = feature_dims
258
269
self .hidden_dim = hidden_dim
259
270
260
- # Define gate networks for each modality
271
+ # Define gate networks for each modality - functions
272
+ # g_m: ℝ^{d_m} → [0,1]
261
273
self .gate_networks = nn .ModuleDict ({
262
274
name : nn .Sequential (
263
275
nn .Linear (dim , hidden_dim ),
@@ -279,7 +291,6 @@ def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
279
291
if feat is None :
280
292
raise ValueError (f"None tensor for modality: { name } " )
281
293
282
-
283
294
def forward (
284
295
self ,
285
296
features : Dict [str , torch .Tensor ]
@@ -294,12 +305,14 @@ def forward(
294
305
if feat is not None and name in self .gate_networks :
295
306
batch_size , seq_len , feat_dim = feat .shape
296
307
297
- # Gate computation sequence
308
+ # Compute gating activation
309
+ # α_m = σ(g_m(x_m))
298
310
flat_feat = feat .reshape (- 1 , feat_dim )
299
311
gate = self .gate_networks [name ](flat_feat )
300
312
gate = gate .reshape (batch_size , seq_len , 1 )
301
313
302
- # Application of gating
314
+ # Apply multiplicative gating
315
+ # y_m = x_m ⊙ α_m
303
316
gated_features [name ] = feat * gate
304
317
305
318
return gated_features
0 commit comments