Skip to content

Commit

Permalink
Fusion blocks fix
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-e-h-p committed Jan 10, 2025
1 parent 09b5ff9 commit f43b497
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions pvnet/models/multimodal/fusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,26 @@ def forward(

return fused_features

#########################################################################################################
# URGENT FIX OF CLASS BELOW AND RE RUN TESTING !!!
#########################################################################################################

class ModalityGating(AbstractFusionBlock):

""" Modality gating mechanism definition """
class ModalityGating(AbstractFusionBlock):
def __init__(
self,
feature_dims: Dict[str, int],
hidden_dim: int = 256,
dropout: float = 0.1
):
# Initialisation of modality gating module
super().__init__()
self.feature_dims = feature_dims
self.hidden_dim = hidden_dim

# Create gate networks for each modality
self.gate_networks = nn.ModuleDict({
name: nn.Sequential(
# Use the actual feature dimension as input size
nn.Linear(dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
Expand All @@ -205,14 +208,17 @@ def forward(
self,
features: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:

# Forward pass for modality gating
gated_features = {}

# Gate value and subsequent application
# Gate value computation and application
for name, feat in features.items():
if feat is not None and self.feature_dims.get(name, 0) > 0:
gate = self.gate_networks[name](feat)
gated_features[name] = feat * gate
if feat is not None and name in self.gate_networks:
# Ensure input tensor has correct shape
if len(feat.shape) == 2:
gate = self.gate_networks[name](feat)
gated_features[name] = feat * gate
else:
raise ValueError(f"Expected 2D tensor for {name}, got shape {feat.shape}")

return gated_features

0 comments on commit f43b497

Please sign in to comment.