Skip to content

Commit f43b497

Browse files
committed
Fusion blocks fix
1 parent 09b5ff9 commit f43b497

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

pvnet/models/multimodal/fusion_blocks.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,23 +174,26 @@ def forward(
174174

175175
return fused_features
176176

177+
#########################################################################################################
178+
# URGENT FIX OF CLASS BELOW AND RE RUN TESTING !!!
179+
#########################################################################################################
177180

178-
class ModalityGating(AbstractFusionBlock):
179181

180-
""" Modality gating mechanism definition """
182+
class ModalityGating(AbstractFusionBlock):
181183
def __init__(
182184
self,
183185
feature_dims: Dict[str, int],
184186
hidden_dim: int = 256,
185187
dropout: float = 0.1
186188
):
187-
# Initialisation of modality gating module
188189
super().__init__()
189190
self.feature_dims = feature_dims
191+
self.hidden_dim = hidden_dim
190192

191193
# Create gate networks for each modality
192194
self.gate_networks = nn.ModuleDict({
193195
name: nn.Sequential(
196+
# Use the actual feature dimension as input size
194197
nn.Linear(dim, hidden_dim),
195198
nn.ReLU(),
196199
nn.Dropout(dropout),
@@ -205,14 +208,17 @@ def forward(
205208
self,
206209
features: Dict[str, torch.Tensor]
207210
) -> Dict[str, torch.Tensor]:
208-
209211
# Forward pass for modality gating
210212
gated_features = {}
211213

212-
# Gate value and subsequent application
214+
# Gate value computation and application
213215
for name, feat in features.items():
214-
if feat is not None and self.feature_dims.get(name, 0) > 0:
215-
gate = self.gate_networks[name](feat)
216-
gated_features[name] = feat * gate
216+
if feat is not None and name in self.gate_networks:
217+
# Ensure input tensor has correct shape
218+
if len(feat.shape) == 2:
219+
gate = self.gate_networks[name](feat)
220+
gated_features[name] = feat * gate
221+
else:
222+
raise ValueError(f"Expected 2D tensor for {name}, got shape {feat.shape}")
217223

218224
return gated_features

0 commit comments

Comments
 (0)