@@ -174,23 +174,26 @@ def forward(
174
174
175
175
return fused_features
176
176
177
+ #########################################################################################################
178
+ # TODO(review): fix the ModalityGating class below and re-run the test suite before merging
179
+ #########################################################################################################
177
180
178
- class ModalityGating (AbstractFusionBlock ):
179
181
180
- """ Modality gating mechanism definition """
182
+ class ModalityGating ( AbstractFusionBlock ):
181
183
def __init__ (
182
184
self ,
183
185
feature_dims : Dict [str , int ],
184
186
hidden_dim : int = 256 ,
185
187
dropout : float = 0.1
186
188
):
187
- # Initialisation of modality gating module
188
189
super ().__init__ ()
189
190
self .feature_dims = feature_dims
191
+ self .hidden_dim = hidden_dim
190
192
191
193
# Create gate networks for each modality
192
194
self .gate_networks = nn .ModuleDict ({
193
195
name : nn .Sequential (
196
+ # Use the actual feature dimension as input size
194
197
nn .Linear (dim , hidden_dim ),
195
198
nn .ReLU (),
196
199
nn .Dropout (dropout ),
@@ -205,14 +208,17 @@ def forward(
205
208
self ,
206
209
features : Dict [str , torch .Tensor ]
207
210
) -> Dict [str , torch .Tensor ]:
208
-
209
211
# Forward pass for modality gating
210
212
gated_features = {}
211
213
212
- # Gate value and subsequent application
214
+ # Gate value computation and application
213
215
for name , feat in features .items ():
214
- if feat is not None and self .feature_dims .get (name , 0 ) > 0 :
215
- gate = self .gate_networks [name ](feat )
216
- gated_features [name ] = feat * gate
216
+ if feat is not None and name in self .gate_networks :
217
+ # Ensure input tensor has correct shape
218
+ if len (feat .shape ) == 2 :
219
+ gate = self .gate_networks [name ](feat )
220
+ gated_features [name ] = feat * gate
221
+ else :
222
+ raise ValueError (f"Expected 2D tensor for { name } , got shape { feat .shape } " )
217
223
218
224
return gated_features
0 commit comments