
Commit cc42d38

Fusion blocks and testing finalisation
1 parent 3509f1b commit cc42d38

2 files changed: +100 −93 lines

pvnet/models/multimodal/fusion_blocks.py

Lines changed: 83 additions & 71 deletions
@@ -6,6 +6,8 @@
 Definition of foundational fusion mechanisms; DynamicFusionModule and ModalityGating

 Aforementioned fusion blocks apply dynamic attention, weighted combinations and/or gating mechanisms for feature learning
+
+In summary, this enables dynamic feature learning through attention-based weighting and modality-specific gating
 """


@@ -31,6 +33,11 @@ def forward(


 class DynamicFusionModule(AbstractFusionBlock):
+
+    """ Implementation of dynamic multimodal fusion through cross attention and weighted combination """
+
+    # Input dimension specified and common embedding dimension
+    # Quantity of attention heads also specified
     def __init__(
         self,
         feature_dims: Dict[str, int],
@@ -40,12 +47,10 @@ def __init__(
         fusion_method: str = "weighted_sum",
         use_residual: bool = True
     ):
-        nn.Module.__init__(self)
+        super().__init__()

-        if hidden_dim <= 0:
-            raise ValueError("hidden_dim must be positive")
-        if num_heads <= 0:
-            raise ValueError("num_heads must be positive")
+        if hidden_dim <= 0 or num_heads <= 0:
+            raise ValueError("hidden_dim and num_heads must be positive")

         self.feature_dims = feature_dims
         self.hidden_dim = hidden_dim
@@ -55,7 +60,7 @@ def __init__(
         if fusion_method not in ["weighted_sum", "concat"]:
             raise ValueError(f"Invalid fusion method: {fusion_method}")

-        # Projections
+        # Projections - modality specific
         self.projections = nn.ModuleDict({
             name: nn.Sequential(
                 nn.Linear(dim, hidden_dim),
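As context for the hunk above: the projection dictionary maps each modality's raw feature width onto the shared hidden size before fusion. A minimal standalone sketch of that pattern, with illustrative dimensions rather than the module's exact layer stack, might look like:

```python
import torch
from torch import nn

feature_dims = {"visual": 512, "text": 256}   # hypothetical modality widths
hidden_dim = 128

# One projection per modality, each mapping its own input width to hidden_dim
projections = nn.ModuleDict({
    name: nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU())
    for name, dim in feature_dims.items() if dim > 0
})

visual = torch.randn(2, 10, 512)
projected = projections["visual"](visual)   # -> [2, 10, 128]
```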
@@ -67,14 +72,14 @@ def __init__(
             if dim > 0
         })

-        # Attention
+        # Attention - cross modal
         self.cross_attention = MultiheadAttention(
             embed_dim=hidden_dim,
             num_heads=num_heads,
             dropout=dropout
         )

-        # Weight network
+        # Weight computation network definition
         self.weight_network = nn.Sequential(
             nn.Linear(hidden_dim, hidden_dim // 2),
             nn.ReLU(),
@@ -96,8 +101,10 @@ def __init__(
         self.layer_norm = nn.LayerNorm(hidden_dim)

     def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
+        """ Validates input feature dimensions and sequence lengths """
+
         if not features:
-            raise ValueError("Empty features dictionary")
+            raise ValueError("Empty features dict")

         seq_length = None
         for name, feat in features.items():
@@ -107,32 +114,26 @@ def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
             if seq_length is None:
                 seq_length = feat.size(1)
             elif feat.size(1) != seq_length:
-                raise ValueError("All modalities must have the same sequence length")
+                raise ValueError("All modalities must have same sequence length")

     def compute_modality_weights(
         self,
         features: torch.Tensor,
         modality_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
-        """Compute weights for each feature.
-
-        Args:
-            features: [batch_size, seq_len, hidden_dim] tensor
-            modality_mask: Optional attention mask
-
-        Returns:
-            [batch_size, seq_len, 1] tensor of weights
-        """
-        # Compute weights for each feature
-        flat_features = features.reshape(-1, features.size(-1))  # [B*S, H]
-        weights = self.weight_network(flat_features)  # [B*S, 1]
-        weights = weights.reshape(features.size(0), features.size(1), 1)  # [B, S, 1]
+
+        """ Computation of attention weights for each feature """
+
+        batch_size, seq_len = features.size(0), features.size(1)
+        flat_features = features.reshape(-1, features.size(-1))
+        weights = self.weight_network(flat_features)
+        weights = weights.reshape(batch_size, seq_len, 1)

         if modality_mask is not None:
-            modality_mask = modality_mask.unsqueeze(-1)  # [B, S, 1]
-            weights = weights.masked_fill(~modality_mask, 0.0)
+            weights = weights.reshape(batch_size, -1, 1)[:, :modality_mask.size(1), :]
+            weights = weights.masked_fill(~modality_mask.unsqueeze(-1), 0.0)

-        # Normalize weights
+        # Normalisation of weights
         weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-9)
         return weights

@@ -141,15 +142,9 @@ def forward(
         features: Dict[str, torch.Tensor],
         modality_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
-        """Forward pass
-
-        Args:
-            features: Dict of [batch_size, seq_len, feature_dim] tensors
-            modality_mask: Optional attention mask
-
-        Returns:
-            [batch_size, hidden_dim] tensor if seq_len=1, else [batch_size, seq_len, hidden_dim]
-        """
+
+        """ Forward pass for dynamic fusion """
+
         self._validate_features(features)

         batch_size = next(iter(features.values())).size(0)
@@ -167,44 +162,64 @@ def forward(
         if not projected_features:
             raise ValueError("No valid features after projection")

-        # Stack and apply attention
-        feature_stack = torch.stack(projected_features, dim=2)  # [B, S, M, H]
+        # Stack features
+        feature_stack = torch.stack(projected_features, dim=1)

-        # Cross attention
-        attended_features = self.cross_attention(
-            feature_stack, feature_stack, feature_stack
-        )  # [B, S, M, H]
-
-        # Average across modalities first
-        attended_avg = attended_features.mean(dim=2)  # [B, S, H]
+        # Apply cross attention
+        attended_features = []
+        for i in range(feature_stack.size(1)):
+            query = feature_stack[:, i]
+            key_value = feature_stack[:, [j for j in range(feature_stack.size(1)) if j != i]]
+            if key_value.size(1) > 0:
+                attended = self.cross_attention(query, key_value.reshape(-1, seq_len, self.hidden_dim),
+                                                key_value.reshape(-1, seq_len, self.hidden_dim))
+                attended_features.append(attended)
+            else:
+                attended_features.append(query)
+
+        # Average across modalities
+        attended_features = torch.stack(attended_features, dim=1)
+        attended_avg = attended_features.mean(dim=1)

-        # Compute weights on averaged features
-        weights = self.compute_modality_weights(attended_avg, modality_mask)  # [B, S, 1]
+        # Mask attended features to match
+        if modality_mask is not None:
+            # Create binary mask matching sequence length
+            seq_mask = torch.zeros((batch_size, seq_len), device=attended_avg.device).bool()
+            seq_mask[:, :modality_mask.size(1)] = modality_mask
+
+            # Compute weights on masked features
+            weights = self.compute_modality_weights(attended_avg, seq_mask)
+            weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1)
+        else:
+            weights = self.compute_modality_weights(attended_avg)
+            weights = weights.unsqueeze(1).expand(-1, attended_features.size(1), -1, 1)

-        # Apply weights
-        weighted_features = attended_features * weights.unsqueeze(2)  # [B, S, M, H]
+        # Application of weighted features
+        weighted_features = attended_features * weights

         if self.fusion_method == "weighted_sum":
-            # Sum across modalities
-            fused = weighted_features.sum(dim=2)  # [B, S, H]
+            fused = weighted_features.sum(dim=1)
         else:
-            # Concatenate modalities
-            concat = weighted_features.reshape(batch_size, seq_len, -1)  # [B, S, M*H]
-            fused = self.output_projection(concat)  # [B, S, H]
+            concat = weighted_features.reshape(batch_size, seq_len, -1)
+            fused = self.output_projection(concat)

-        # Apply residual if needed
+        # Application of residual
         if self.use_residual:
-            residual = feature_stack.mean(dim=2)  # [B, S, H]
+            residual = feature_stack.mean(dim=1)
             fused = self.layer_norm(fused + residual)

-        # Remove sequence dimension if length is 1
-        if seq_len == 1:
-            fused = fused.squeeze(1)
+        # Collapse sequence dimension for output
+        fused = fused.mean(dim=1)

         return fused


+
+
 class ModalityGating(AbstractFusionBlock):
+    """ Implementation of modality specific gating mechanism """
+
+    # Input and hidden dimension definition
     def __init__(
         self,
         feature_dims: Dict[str, int],
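The per-modality loop added in the hunk above follows a leave-one-out pattern: each modality queries the concatenation of the remaining modalities. Below is a rough, self-contained sketch of that indexing using torch.nn.MultiheadAttention with batch_first=True and toy shapes; it is not the module's own attention wrapper or its exact reshaping, just the general shape of the computation:

```python
import torch
from torch import nn

batch, num_modalities, seq_len, hidden = 2, 3, 4, 16
feature_stack = torch.randn(batch, num_modalities, seq_len, hidden)  # [B, M, S, H]
attention = nn.MultiheadAttention(embed_dim=hidden, num_heads=4, batch_first=True)

attended = []
for i in range(num_modalities):
    query = feature_stack[:, i]                                       # [B, S, H]
    others = [j for j in range(num_modalities) if j != i]
    key_value = feature_stack[:, others].reshape(batch, -1, hidden)   # [B, (M-1)*S, H]
    out, _ = attention(query, key_value, key_value)                   # attend over other modalities
    attended.append(out)

attended = torch.stack(attended, dim=1)   # [B, M, S, H]
fused = attended.mean(dim=1)              # [B, S, H], averaged across modalities
```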
@@ -219,7 +234,7 @@ def __init__(
         self.feature_dims = feature_dims
         self.hidden_dim = hidden_dim

-        # Create gate networks for each modality
+        # Define gate networks for each modality
         self.gate_networks = nn.ModuleDict({
             name: nn.Sequential(
                 nn.Linear(dim, hidden_dim),
@@ -233,9 +248,10 @@ def __init__(
         })

     def _validate_features(self, features: Dict[str, torch.Tensor]) -> None:
+        """ Validation helper for input feature dict """

         if not features:
-            raise ValueError("Empty features dictionary")
+            raise ValueError("Empty features dict")
         for name, feat in features.items():
             if feat is None:
                 raise ValueError(f"None tensor for modality: {name}")
@@ -246,25 +262,21 @@ def forward(
         features: Dict[str, torch.Tensor]
     ) -> Dict[str, torch.Tensor]:

-        self._validate_features(features)
+        """ Application of modality specific gating """

+        self._validate_features(features)
         gated_features = {}

         for name, feat in features.items():
             if feat is not None and name in self.gate_networks:
-                # Handle 3D tensors (batch_size, sequence_length, feature_dim)
                 batch_size, seq_len, feat_dim = feat.shape
-
-                # Reshape to (batch_size * seq_len, feature_dim)
-                flat_feat = feat.reshape(-1, feat_dim)
-
-                # Compute gates
-                gate = self.gate_networks[name](flat_feat)
-
-                # Reshape gates back to match input
+
+                # Gate computation sequence
+                flat_feat = feat.reshape(-1, feat_dim)
+                gate = self.gate_networks[name](flat_feat)
                 gate = gate.reshape(batch_size, seq_len, 1)

-                # Apply gating
+                # Application of gating
                 gated_features[name] = feat * gate

         return gated_features
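Taken together, the two blocks in this file gate each modality's features and then fuse them into a single vector per sample. A hedged usage sketch, assuming the constructor signatures visible in this diff and purely illustrative modality names, shapes, and hyperparameters (the real configuration lives elsewhere in PVNet):

```python
import torch
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating

feature_dims = {"sat": 256, "nwp": 128}          # hypothetical modality dims
features = {
    "sat": torch.randn(4, 6, 256),               # [batch, seq_len, feature_dim]
    "nwp": torch.randn(4, 6, 128),
}

gating = ModalityGating(feature_dims=feature_dims, hidden_dim=64)
fusion = DynamicFusionModule(feature_dims=feature_dims, hidden_dim=64, num_heads=4)

gated = gating(features)   # per-modality gated features, same shapes as the inputs
fused = fusion(gated)      # per this commit, the sequence dim is averaged out -> [batch, hidden_dim]
```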

tests/models/multimodal/test_fusion_blocks.py

Lines changed: 17 additions & 22 deletions
@@ -47,9 +47,9 @@ def attention_mask(config):


 # DynamicFusionModule Tests
-def test_dynamic_fusion_initialization(config):
-    """ Verify initialization and parameter validation """
-    # Test valid initialization
+def test_dynamic_fusion_initialisation(config):
+    """ Verify initialisation and parameter validation """
+    # Test valid initialisation
     fusion = DynamicFusionModule(
         feature_dims=config['feature_dims'],
         hidden_dim=config['hidden_dim'],
@@ -62,20 +62,13 @@ def test_dynamic_fusion_initialization(config):
     assert fusion.hidden_dim == config['hidden_dim']
     assert isinstance(fusion.cross_attention.embed_dim, int)
     assert isinstance(fusion.weight_network, torch.nn.Sequential)
-
-    # Test invalid hidden_dim
-    with pytest.raises(ValueError, match="hidden_dim must be positive"):
+
+    # Test invalid hidden_dim and num_heads
+    with pytest.raises(ValueError, match="hidden_dim and num_heads must be positive"):
         DynamicFusionModule(
             feature_dims=config['feature_dims'],
             hidden_dim=0
-        )
-
-    # Test invalid num_heads
-    with pytest.raises(ValueError, match="num_heads must be positive"):
-        DynamicFusionModule(
-            feature_dims=config['feature_dims'],
-            num_heads=0
-        )
+        )


 def test_dynamic_fusion_feature_validation(config, multimodal_features):
@@ -86,7 +79,7 @@ def test_dynamic_fusion_feature_validation(config, multimodal_features):
     )

     # Test empty features
-    with pytest.raises(ValueError, match="Empty features dictionary"):
+    with pytest.raises(ValueError, match="Empty features dict"):
         fusion({})

     # Test None tensor
@@ -177,12 +170,12 @@ def test_dynamic_fusion_different_sequence_lengths(config):
         'audio': torch.randn(config['batch_size'], 12, config['feature_dims']['audio'])
     }

-    with pytest.raises(ValueError, match=r"All modalities must have the same sequence length"):
+    with pytest.raises(ValueError, match="All modalities must have same sequence length"):
         output = fusion(varying_features)


 # ModalityGating Tests
-def test_modality_gating_initialization(config):
+def test_modality_gating_initialisation(config):
     """ Verify initialisation """
     gating = ModalityGating(
         feature_dims=config['feature_dims'],
@@ -193,6 +186,7 @@ def test_modality_gating_initialization(config):
     assert len(gating.gate_networks) == len(config['feature_dims'])
     for name, network in gating.gate_networks.items():
         assert isinstance(network, torch.nn.Sequential)
+
         # Verify input dimension of first layer matches feature dimension
         assert network[0].in_features == config['feature_dims'][name]

@@ -215,12 +209,13 @@ def test_modality_gating_forward(config, multimodal_features):
     # Verify output shapes and properties
     assert len(outputs) == len(multimodal_features)
     for modality, output in outputs.items():
-        assert output.shape == multimodal_features[modality].shape  # Should match 3D input shape
-        assert len(output.shape) == 3  # Ensure 3D output (batch, sequence, features)
+        assert output.shape == multimodal_features[modality].shape
+        assert len(output.shape) == 3
         assert not torch.isnan(output).any()
         assert not torch.isinf(output).any()
+
         # Verify gating values are between 0 and 1
-        gates = output / (multimodal_features[modality] + 1e-8)  # Avoid division by zero
+        gates = output / (multimodal_features[modality] + 1e-8)
         assert torch.all((gates >= 0) & (gates <= 1 + 1e-6))


@@ -267,7 +262,7 @@ def test_modality_gating_edge_cases(config):
     gating = ModalityGating(feature_dims={'visual': 64})

     # Empty input validation
-    with pytest.raises(ValueError, match="Empty features dictionary"):
+    with pytest.raises(ValueError, match="Empty features dict"):
         gating({})

     # Test with single timestep
@@ -306,4 +301,4 @@ def test_modality_gating_different_sequence_lengths(config):

     # Verify shapes are maintained
     assert outputs['visual'].shape == varying_features['visual'].shape
-    assert outputs['text'].shape == varying_features['text'].shape
+    assert outputs['text'].shape == varying_features['text'].shape
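As a side note on the gate-range assertion retained in test_modality_gating_forward above: dividing the gated output by the slightly shifted input recovers an approximation of the gate, which should lie in [0, 1] when the gate networks end in a sigmoid. A self-contained sketch of that check with synthetic tensors (not the fixtures used in this suite):

```python
import torch

feat = torch.randn(2, 5, 8)                  # synthetic [batch, seq, features]
gate = torch.sigmoid(torch.randn(2, 5, 1))   # gates in (0, 1), as a sigmoid head would produce
gated = feat * gate

# Small shift avoids division by zero; tolerance absorbs the resulting rounding
recovered = gated / (feat + 1e-8)
assert torch.all((recovered >= 0) & (recovered <= 1 + 1e-6))
```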
