Definition of foundational fusion mechanisms: DynamicFusionModule and ModalityGating

The aforementioned fusion blocks apply dynamic attention, weighted combinations and/or gating mechanisms for feature learning
+
+ In summary, this enables dynamic feature learning through attention-based weighting and modality-specific gating
"""

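As context for the changes below, the attention-based weighting boils down to a per-position score that is normalised over the sequence dimension. A minimal sketch (with illustrative tensor shapes, not taken from the diff) of the normalisation used in compute_modality_weights further down:

```python
import torch

# Illustrative shapes: [batch, seq, 1] scores, e.g. as produced by weight_network
scores = torch.rand(2, 5, 1)

# Rescale so the scores sum to 1 over the sequence dimension (epsilon avoids division by zero),
# mirroring: weights / (weights.sum(dim=1, keepdim=True) + 1e-9)
weights = scores / (scores.sum(dim=1, keepdim=True) + 1e-9)

assert torch.allclose(weights.sum(dim=1), torch.ones(2, 1), atol=1e-4)
```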
@@ -31,6 +33,11 @@ def forward(


class DynamicFusionModule(AbstractFusionBlock):
+
+    """Dynamic multimodal fusion through cross-attention and weighted combination."""
+
+    # feature_dims maps each modality to its input dimension; hidden_dim is the
+    # common embedding dimension and num_heads the number of attention heads.
    def __init__(
        self,
        feature_dims: Dict[str, int],
@@ -40,12 +47,10 @@ def __init__(
        fusion_method: str = "weighted_sum",
        use_residual: bool = True
    ):
-        nn.Module.__init__(self)
+        super().__init__()

-        if hidden_dim <= 0:
-            raise ValueError("hidden_dim must be positive")
-        if num_heads <= 0:
-            raise ValueError("num_heads must be positive")
+        if hidden_dim <= 0 or num_heads <= 0:
+            raise ValueError("hidden_dim and num_heads must be positive")

        self.feature_dims = feature_dims
        self.hidden_dim = hidden_dim
@@ -55,7 +60,7 @@ def __init__(
        if fusion_method not in ["weighted_sum", "concat"]:
            raise ValueError(f"Invalid fusion method: {fusion_method}")

-        # Projections
+        # Modality-specific projections
        self.projections = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(dim, hidden_dim),
@@ -67,14 +72,14 @@ def __init__(
            if dim > 0
        })

-        # Attention
+        # Cross-modal attention
        self.cross_attention = MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout
        )

-        # Weight network
+        # Weight computation network
        self.weight_network = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
@@ -96,8 +101,10 @@ def __init__(
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
+        """Validates input feature dimensions and sequence lengths."""
+
        if not features:
-            raise ValueError("Empty features dictionary")
+            raise ValueError("Empty features dict")

        seq_length = None
        for name, feat in features.items():
@@ -107,32 +114,26 @@ def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
            if seq_length is None:
                seq_length = feat.size(1)
            elif feat.size(1) != seq_length:
-                raise ValueError("All modalities must have the same sequence length")
+                raise ValueError("All modalities must have same sequence length")

    def compute_modality_weights(
        self,
        features: torch.Tensor,
        modality_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
-        """Compute weights for each feature.
-
-        Args:
-            features: [batch_size, seq_len, hidden_dim] tensor
-            modality_mask: Optional attention mask
-
-        Returns:
-            [batch_size, seq_len, 1] tensor of weights
-        """
-        # Compute weights for each feature
-        flat_features = features.reshape(-1, features.size(-1))  # [B*S, H]
-        weights = self.weight_network(flat_features)  # [B*S, 1]
-        weights = weights.reshape(features.size(0), features.size(1), 1)  # [B, S, 1]
+
+        """Compute attention weights for each feature."""
+
+        batch_size, seq_len = features.size(0), features.size(1)
+        flat_features = features.reshape(-1, features.size(-1))
+        weights = self.weight_network(flat_features)
+        weights = weights.reshape(batch_size, seq_len, 1)

        if modality_mask is not None:
-            modality_mask = modality_mask.unsqueeze(-1)  # [B, S, 1]
-            weights = weights.masked_fill(~modality_mask, 0.0)
+            weights = weights.reshape(batch_size, -1, 1)[:, :modality_mask.size(1), :]
+            weights = weights.masked_fill(~modality_mask.unsqueeze(-1), 0.0)

-        # Normalize weights
+        # Normalise the weights
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)
        return weights

@@ -141,15 +142,9 @@ def forward(
        features: Dict[str, torch.Tensor],
        modality_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
-        """Forward pass
-
-        Args:
-            features: Dict of [batch_size, seq_len, feature_dim] tensors
-            modality_mask: Optional attention mask
-
-        Returns:
-            [batch_size, hidden_dim] tensor if seq_len=1, else [batch_size, seq_len, hidden_dim]
-        """
+
+        """Forward pass for dynamic fusion."""
+
        self._validate_features(features)

        batch_size = next(iter(features.values())).size(0)
@@ -167,44 +162,64 @@ def forward(
        if not projected_features:
            raise ValueError("No valid features after projection")

-        # Stack and apply attention
-        feature_stack = torch.stack(projected_features, dim=2)  # [B, S, M, H]
+        # Stack features
+        feature_stack = torch.stack(projected_features, dim=1)

-        # Cross attention
-        attended_features = self.cross_attention(
-            feature_stack, feature_stack, feature_stack
-        )  # [B, S, M, H]
-
-        # Average across modalities first
-        attended_avg = attended_features.mean(dim=2)  # [B, S, H]
+        # Apply cross-attention: each modality attends to the others
+        attended_features = []
+        for i in range(feature_stack.size(1)):
+            query = feature_stack[:, i]
+            key_value = feature_stack[:, [j for j in range(feature_stack.size(1)) if j != i]]
+            if key_value.size(1) > 0:
+                attended = self.cross_attention(query, key_value.reshape(-1, seq_len, self.hidden_dim),
+                                                key_value.reshape(-1, seq_len, self.hidden_dim))
+                attended_features.append(attended)
+            else:
+                attended_features.append(query)
+
+        # Average across modalities
+        attended_features = torch.stack(attended_features, dim=1)
+        attended_avg = attended_features.mean(dim=1)

-        # Compute weights on averaged features
-        weights = self.compute_modality_weights(attended_avg, modality_mask)  # [B, S, 1]
+        # Mask the attended features when a modality mask is provided
+        if modality_mask is not None:
+            # Create a binary mask matching the sequence length
+            seq_mask = torch.zeros((batch_size, seq_len), device=attended_avg.device).bool()
+            seq_mask[:, :modality_mask.size(1)] = modality_mask
+
+            # Compute weights on masked features
+            weights = self.compute_modality_weights(attended_avg, seq_mask)
+            weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1)
+        else:
+            weights = self.compute_modality_weights(attended_avg)
+            weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1)

-        # Apply weights
-        weighted_features = attended_features * weights.unsqueeze(2)  # [B, S, M, H]
+        # Apply the weights
+        weighted_features = attended_features * weights

        if self.fusion_method == "weighted_sum":
-            # Sum across modalities
-            fused = weighted_features.sum(dim=2)  # [B, S, H]
+            fused = weighted_features.sum(dim=1)
        else:
-            # Concatenate modalities
-            concat = weighted_features.reshape(batch_size, seq_len, -1)  # [B, S, M*H]
-            fused = self.output_projection(concat)  # [B, S, H]
+            concat = weighted_features.reshape(batch_size, seq_len, -1)
+            fused = self.output_projection(concat)

-        # Apply residual if needed
+        # Apply the residual connection
        if self.use_residual:
-            residual = feature_stack.mean(dim=2)  # [B, S, H]
+            residual = feature_stack.mean(dim=1)
            fused = self.layer_norm(fused + residual)

-        # Remove sequence dimension if length is 1
-        if seq_len == 1:
-            fused = fused.squeeze(1)
+        # Collapse the sequence dimension for the output (mean pooling)
+        fused = fused.mean(dim=1)

        return fused


+
+
class ModalityGating(AbstractFusionBlock):
+    """Modality-specific gating mechanism."""
+
+    # feature_dims maps each modality to its input dimension; hidden_dim sizes the gate networks
    def __init__(
        self,
        feature_dims: Dict[str, int],
@@ -219,7 +234,7 @@ def __init__(
        self.feature_dims = feature_dims
        self.hidden_dim = hidden_dim

-        # Create gate networks for each modality
+        # Define gate networks for each modality
        self.gate_networks = nn.ModuleDict({
            name: nn.Sequential(
                nn.Linear(dim, hidden_dim),
@@ -233,9 +248,10 @@ def __init__(
        })

    def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
+        """Validation helper for input feature dict."""

        if not features:
-            raise ValueError("Empty features dictionary")
+            raise ValueError("Empty features dict")
        for name, feat in features.items():
            if feat is None:
                raise ValueError(f"None tensor for modality: {name}")
@@ -246,25 +262,21 @@ def forward(
        features: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:

-        self._validate_features(features)
+        """Apply modality-specific gating."""

+        self._validate_features(features)
        gated_features = {}

        for name, feat in features.items():
            if feat is not None and name in self.gate_networks:
-                # Handle 3D tensors (batch_size, sequence_length, feature_dim)
                batch_size, seq_len, feat_dim = feat.shape
-
-                # Reshape to (batch_size * seq_len, feature_dim)
-                flat_feat = feat.reshape(-1, feat_dim)
-
-                # Compute gates
-                gate = self.gate_networks[name](flat_feat)
-
-                # Reshape gates back to match input
+
+                # Compute a gate value for each position and reshape to match the input
+                flat_feat = feat.reshape(-1, feat_dim)
+                gate = self.gate_networks[name](flat_feat)
                gate = gate.reshape(batch_size, seq_len, 1)

-                # Apply gating
+                # Apply the gate
                gated_features[name] = feat * gate

        return gated_features
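For reference, a minimal usage sketch of DynamicFusionModule as it stands after this change. The import location, modality names, and dimensions are illustrative assumptions, not part of the diff; after the change the output is mean-pooled over the sequence dimension.

```python
import torch
# from <package>.fusion import DynamicFusionModule  # assumed import location

fusion = DynamicFusionModule(
    feature_dims={"audio": 128, "video": 256},  # assumed modality names and sizes
    hidden_dim=64,
    num_heads=4,          # should divide hidden_dim evenly
    dropout=0.1,
    fusion_method="weighted_sum",
    use_residual=True,
)

batch_size, seq_len = 8, 10
features = {
    "audio": torch.randn(batch_size, seq_len, 128),
    "video": torch.randn(batch_size, seq_len, 256),
}

fused = fusion(features)  # expected shape: [batch_size, hidden_dim]
```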
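Similarly, a hedged sketch of ModalityGating, which returns a dict of gated tensors with the same shapes as its inputs; any constructor arguments beyond feature_dims and hidden_dim are left at their defaults here, and the names and sizes are again illustrative.

```python
import torch
# from <package>.fusion import ModalityGating  # assumed import location

gating = ModalityGating(
    feature_dims={"audio": 128, "video": 256},  # assumed modality names and sizes
    hidden_dim=64,
)

features = {
    "audio": torch.randn(8, 10, 128),
    "video": torch.randn(8, 10, 256),
}

gated = gating(features)
assert gated["audio"].shape == features["audio"].shape  # gating preserves input shapes
```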