def get_compatible_heads(dim: int, target_heads: int) -> int:
    """Calculate largest compatible number of heads <= target_heads"""

+   # Iterative reduction
+   # Obtain maximum divisible number of heads
+   # h ∈ ℕ : h ≤ target_heads ∧ dim mod h = 0
    for h in range(min(target_heads, dim), 0, -1):
        if dim % h == 0:
            return h
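A quick sanity check of the helper above (illustrative values, not part of the diff) — it walks down from min(target_heads, dim) and returns the first divisor of dim it finds:

    get_compatible_heads(48, 8)   # -> 8, since 48 % 8 == 0
    get_compatible_heads(50, 8)   # -> 5; neither 8, 7 nor 6 divides 50
    get_compatible_heads(7, 4)    # -> 1; 7 is prime, so no h in 2..4 divides it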
@@ -32,11 +35,17 @@ class PVEncoder(nn.Module):

    def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        super().__init__()
+
+       # Temporal and spatial configuration parameters
+       # L: sequence length
+       # M: number of sites
        self.sequence_length = sequence_length
        self.num_sites = num_sites
        self.out_features = out_features

        # Basic feature extraction network
+       # φ: ℝ^M → ℝ^N
+       # Linear Transformation → Layer Normalization → ReLU → Dropout
        self.encoder = nn.Sequential(
            nn.Linear(num_sites, out_features),
            nn.LayerNorm(out_features),
@@ -47,10 +56,12 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):
    def forward(self, x):

        # Sequential processing - maintain temporal order
+       # x ∈ ℝ^{B×L×M} → out ∈ ℝ^{B×L×N}
        batch_size = x.shape[0]
        out = []
        for t in range(self.sequence_length):
            out.append(self.encoder(x[:, t]))
+
        # Reshape maintaining sequence dimension
        return torch.stack(out, dim=1)
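As a quick shape check of the encoder loop above (hypothetical sizes, not taken from the PR): each x[:, t] slice is (batch, num_sites), each encoded step is (batch, out_features), and stacking along dim=1 restores the sequence axis.

    enc = PVEncoder(sequence_length=12, num_sites=10, out_features=32)
    y = enc(torch.randn(4, 12, 10))    # y.shape == torch.Size([4, 12, 32])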

@@ -84,17 +95,22 @@ def __init__(
        )

        # Dimension validation and compatibility
+       # Adjust hidden dimension to be divisible by sequence length
+       # H = feature_dim × sequence_length
        if hidden_dim % sequence_length != 0:
            feature_dim = ((hidden_dim + sequence_length - 1) // sequence_length)
            hidden_dim = feature_dim * sequence_length
        else:
            feature_dim = hidden_dim // sequence_length

-       # Attention mechanism setup
+       # Attention head compatibility check
+       # Select maximum compatible head count
+       # h ∈ ℕ : h ≤ num_heads ∧ feature_dim mod h = 0
        attention_heads = cross_attention.get('num_heads', num_heads)
        attention_heads = get_compatible_heads(feature_dim, attention_heads)

-       # Feature dimension adjustment for attention
+       # Dimension adjustment for attention mechanism
+       # Ensure feature dimension is compatible with attention heads
        if feature_dim < attention_heads:
            feature_dim = attention_heads
            hidden_dim = feature_dim * sequence_length
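To make the rounding above concrete, a hypothetical configuration (not taken from the PR) of hidden_dim=100, sequence_length=12 and num_heads=8 would resolve as:

    # hidden_dim=100, sequence_length=12, num_heads=8       (hypothetical values)
    # feature_dim     = (100 + 12 - 1) // 12 = 9
    # hidden_dim      = 9 * 12 = 108
    # attention_heads = get_compatible_heads(9, 8) = 3
    # feature_dim (9) >= attention_heads (3), so no further adjustment is needed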
@@ -164,15 +180,16 @@ def __init__(
            'hidden_dim': feature_dim
        })
        self.gating = ModalityGating(**gating_config)
-
+
        # Cross modal attention mechanism
-        self.use_cross_attention = use_cross_attention
-        if use_cross_attention:
+        self.use_cross_attention = use_cross_attention and len(modality_channels) > 1
+        if self.use_cross_attention:
            attention_config = cross_attention.copy()
            attention_config.update({
                'embed_dim': feature_dim,
                'num_heads': attention_heads,
-                'dropout': dropout
+                'dropout': dropout,
+                'num_modalities': len(modality_channels)
            })
            self.cross_attention = CrossModalAttention(**attention_config)

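Assuming, purely for illustration, that the cross_attention dict starts empty and that feature_dim=9, attention_heads=3, dropout=0.1 and two modalities are configured, the update above would yield the config sketched below; use_cross_attention only stays True because more than one modality is present.

    attention_config = {'embed_dim': 9, 'num_heads': 3, 'dropout': 0.1, 'num_modalities': 2}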
@@ -195,45 +212,84 @@ def __init__(
            nn.Linear(fc_features, out_features),
            nn.ELU(),
        )
-
+
    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

-        """Dynamic fusion forward pass implementation"""
-
+        # Encoded features dictionary
+        # M ∈ {x_m | m ∈ Modalities}
        encoded_features = {}

        # Modality specific encoding
+        # x_m ∈ ℝ^{B×L×C_m} → encoded ∈ ℝ^{B×L×D}
        for modality, x in inputs.items():
            if modality not in self.modality_encoders or x is None:
                continue
-
+
            # Feature extraction and projection
            encoded = self.modality_encoders[modality](x)
+            print(f"Encoded {modality} shape: {encoded.shape}")
+
+            # Temporal projection across sequence
+            # π: ℝ^{B×L×D} → ℝ^{B×L×D}
            projected = torch.stack([
                self.feature_projections[modality](encoded[:, t])
                for t in range(self.sequence_length)
            ], dim=1)
+            print(f"Projected {modality} shape: {projected.shape}")

            encoded_features[modality] = projected
-
+
+        # Validation of encoded feature space
+        # |M| > 0
        if not encoded_features:
            raise ValueError("No valid features after encoding")

        # Apply modality interaction mechanisms
        if self.use_gating:
+
+            # g: M → M̂
+            # Adaptive feature transformation with learned gates
            encoded_features = self.gating(encoded_features)
-
-        if self.use_cross_attention and len(encoded_features) > 1:
-            encoded_features = self.cross_attention(encoded_features, mask)
-
+            print(f"After gating, encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
+        # Cross-modal attention mechanism
+        if self.use_cross_attention:
+            if len(encoded_features) > 1:
+
+                # Multi-modal cross attention
+                encoded_features = self.cross_attention(encoded_features, mask)
+            else:
+
+                # Single modality: no cross-modal partner, pass features through unchanged
+                for key in encoded_features:
+                    encoded_features[key] = encoded_features[key]  # Identity mapping
+
+        print(f"After cross-modal attention - encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
        # Feature fusion and output generation
        fused_features = self.fusion_module(encoded_features, mask)
+        print(f"Fused features shape: {fused_features.shape}")
+
+        # Ensure input to final_block matches hidden_dim
+        # z ∈ ℝ^{B×H}, H: hidden dimension
        batch_size = fused_features.size(0)
-        fused_features = fused_features.repeat(1, self.sequence_length)
+
+        # Repeat the features to match the expected hidden dimension
+        if fused_features.size(1) != self.hidden_dim:
+            fused_features = fused_features.repeat(1, self.hidden_dim // fused_features.size(1))
+
+        # Fallback projection if a dimension mismatch persists
+        # π_H: ℝ^k → ℝ^H
+        if fused_features.size(1) != self.hidden_dim:
+            projection = nn.Linear(fused_features.size(1), self.hidden_dim).to(fused_features.device)
+            fused_features = projection(fused_features)
+
+        # Final output generation
+        # ψ: ℝ^H → ℝ^M, M: output features
        output = self.final_block(fused_features)

        return output
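A concrete trace of the dimension fix-up above, assuming (hypothetically) feature_dim=9, sequence_length=12, hidden_dim=108 and a fusion output of one vector per sample:

    # fused_features (B, 9)  -> repeat 108 // 9  = 12 times -> (B, 108) -> fallback projection skipped
    # fused_features (B, 10) -> repeat 108 // 10 = 10 times -> (B, 100) -> ad-hoc nn.Linear(100, 108) applied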
@@ -246,6 +302,8 @@ def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Enhanced projection with residual pathways
+       # With residual transformation
+       # φ_m: ℝ^H → ℝ^H
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.LayerNorm(self.hidden_dim),
@@ -266,28 +324,35 @@ def forward(

        """Forward implementation with residual pathways"""

+       # Encoded features dictionary
        encoded_features = {}

        # Feature extraction with residual connections
+       # x_m + R_m(x_m)
        for modality, x in inputs.items():
            if modality not in self.modality_encoders or x is None:
                continue

-           encoded = self.modality_encoders[modality](x)
+           encoded = self.modality_encoders[modality](x)
+
+           # Residual connection
+           # x_m ⊕ R_m(x_m)
            projected = encoded + self.feature_projections[modality](encoded)
            encoded_features[modality] = projected

        if not encoded_features:
            raise ValueError("No valid features after encoding")

        # Gating with residual pathways
+       # g_m: x_m ⊕ g(x_m)
        if self.use_gating:
            gated_features = self.gating(encoded_features)
            for modality in encoded_features:
                gated_features[modality] = gated_features[modality] + encoded_features[modality]
            encoded_features = gated_features

        # Attention with residual pathways
+       # A_m: x_m ⊕ A(x_m)
        if self.use_cross_attention and len(encoded_features) > 1:
            attended_features = self.cross_attention(encoded_features, mask)
            for modality in encoded_features: