
Commit af54862

Dynamic encoder update
1 parent b6fabde commit af54862

File tree

2 files changed: +90 −49 lines

pvnet/models/multimodal/encoders/dynamic_encoder.py

Lines changed: 82 additions & 17 deletions
@@ -20,6 +20,9 @@
 def get_compatible_heads(dim: int, target_heads: int) -> int:
     """ Calculate largest compatible number of heads <= target_heads """

+    # Iterative reduction
+    # Obtain maximum divisible number of heads
+    # h ∈ ℕ : h ≤ target_heads ∧ dim mod h = 0
     for h in range(min(target_heads, dim), 0, -1):
         if dim % h == 0:
             return h
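
For intuition, a minimal sketch of how this search behaves for a few illustrative inputs (values hypothetical, assuming only the function above):

# Largest h <= target_heads with dim % h == 0:
assert get_compatible_heads(dim=64, target_heads=8) == 8   # 64 % 8 == 0
assert get_compatible_heads(dim=48, target_heads=10) == 8  # 10 and 9 fail; 48 % 8 == 0
assert get_compatible_heads(dim=7, target_heads=4) == 1    # prime dim falls back to h = 1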
@@ -32,11 +35,17 @@ class PVEncoder(nn.Module):

     def __init__(self, sequence_length: int, num_sites: int, out_features: int):
         super().__init__()
+
+        # Temporal and spatial configuration parameters
+        # L: sequence length
+        # M: number of sites
         self.sequence_length = sequence_length
         self.num_sites = num_sites
         self.out_features = out_features

         # Basic feature extraction network
+        # φ: ℝ^M → ℝ^N
+        # Linear Transformation → Layer Normalization → ReLU → Dropout
         self.encoder = nn.Sequential(
             nn.Linear(num_sites, out_features),
             nn.LayerNorm(out_features),
@@ -47,10 +56,12 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):

     def forward(self, x):

         # Sequential processing - maintain temporal order
+        # x ∈ ℝ^{B×L×M} → out ∈ ℝ^{B×L×N}
         batch_size = x.shape[0]
         out = []
         for t in range(self.sequence_length):
             out.append(self.encoder(x[:, t]))
+
         # Reshape maintaining sequence dimension
         return torch.stack(out, dim=1)
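
For reference, a shape walkthrough of this per-timestep loop, using illustrative sizes (a sketch, not part of the commit):

import torch

encoder = PVEncoder(sequence_length=12, num_sites=10, out_features=32)
x = torch.randn(4, 12, 10)        # (B, L, M)
out = encoder(x)                  # each x[:, t] maps (B, M) -> (B, N), stacked on dim=1
assert out.shape == (4, 12, 32)   # (B, L, N)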

@@ -84,17 +95,22 @@ def __init__(
         )

         # Dimension validation and compatibility
+        # Adjust hidden dimension to be divisible by sequence length
+        # H = feature_dim × sequence_length
         if hidden_dim % sequence_length != 0:
             feature_dim = ((hidden_dim + sequence_length - 1) // sequence_length)
             hidden_dim = feature_dim * sequence_length
         else:
             feature_dim = hidden_dim // sequence_length

-        # Attention mechanism setup
+        # Attention head compatibility check
+        # Select maximum compatible head count
+        # h ∈ ℕ : h ≤ num_heads ∧ feature_dim mod h = 0
         attention_heads = cross_attention.get('num_heads', num_heads)
         attention_heads = get_compatible_heads(feature_dim, attention_heads)

-        # Feature dimension adjustment for attention
+        # Dimension adjustment for attention mechanism
+        # Ensure feature dimension is compatible with attention heads
         if feature_dim < attention_heads:
             feature_dim = attention_heads
             hidden_dim = feature_dim * sequence_length
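
Worked numbers for the adjustment above, assuming hidden_dim=100, sequence_length=12, num_heads=8 (hypothetical values):

hidden_dim, sequence_length, num_heads = 100, 12, 8
feature_dim = (hidden_dim + sequence_length - 1) // sequence_length   # -> 9
hidden_dim = feature_dim * sequence_length                            # -> 108
attention_heads = get_compatible_heads(feature_dim, num_heads)        # 8..4 fail; 9 % 3 == 0 -> 3
assert (feature_dim, hidden_dim, attention_heads) == (9, 108, 3)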
@@ -164,15 +180,16 @@ def __init__(
             'hidden_dim': feature_dim
         })
         self.gating = ModalityGating(**gating_config)
-
+
         # Cross modal attention mechanism
-        self.use_cross_attention = use_cross_attention
-        if use_cross_attention:
+        self.use_cross_attention = use_cross_attention and len(modality_channels) > 1
+        if self.use_cross_attention:
             attention_config = cross_attention.copy()
             attention_config.update({
                 'embed_dim': feature_dim,
                 'num_heads': attention_heads,
-                'dropout': dropout
+                'dropout': dropout,
+                'num_modalities': len(modality_channels)
             })
             self.cross_attention = CrossModalAttention(**attention_config)
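
The effect of the new guard, reduced to a self-contained sketch (the channel map below is illustrative, not the fixture's real configuration):

# With a single modality, cross attention is disabled up front,
# so no CrossModalAttention module is constructed.
modality_channels = {'pv': 1}
use_cross_attention = True
use_cross_attention = use_cross_attention and len(modality_channels) > 1
assert use_cross_attention is False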

@@ -195,45 +212,84 @@ def __init__(
             nn.Linear(fc_features, out_features),
             nn.ELU(),
         )
-
+
     def forward(
         self,
         inputs: Dict[str, torch.Tensor],
         mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:

-        """ Dynamic fusion forward pass implementation """
-
+        # Encoded features dictionary
+        # M ∈ {x_m | m ∈ Modalities}
         encoded_features = {}

         # Modality specific encoding
+        # x_m ∈ ℝ^{B×L×C_m} → encoded ∈ ℝ^{B×L×D}
         for modality, x in inputs.items():
             if modality not in self.modality_encoders or x is None:
                 continue
-
+
             # Feature extraction and projection
             encoded = self.modality_encoders[modality](x)
+            print(f"Encoded {modality} shape: {encoded.shape}")
+
+            # Temporal projection across sequence
+            # π: ℝ^{B×L×D} → ℝ^{B×L×D}
             projected = torch.stack([
                 self.feature_projections[modality](encoded[:, t])
                 for t in range(self.sequence_length)
             ], dim=1)
+            print(f"Projected {modality} shape: {projected.shape}")

             encoded_features[modality] = projected
-
+
+        # Validation of encoded feature space
+        # |M| > 0
         if not encoded_features:
             raise ValueError("No valid features after encoding")

         # Apply modality interaction mechanisms
         if self.use_gating:
+
+            # g: M → M̂
+            # Adaptive feature transformation with learned gates
             encoded_features = self.gating(encoded_features)
-
-        if self.use_cross_attention and len(encoded_features) > 1:
-            encoded_features = self.cross_attention(encoded_features, mask)
-
+            print(f"After gating, encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
+        # Cross-modal attention mechanism
+        if self.use_cross_attention:
+            if len(encoded_features) > 1:
+
+                # Multi-modal cross attention
+                encoded_features = self.cross_attention(encoded_features, mask)
+            else:
+
+                # Single modality: features pass through unchanged
+                for key in encoded_features:
+                    encoded_features[key] = encoded_features[key]  # Identity mapping
+
+        print(f"After cross-modal attention - encoded_features shapes: {[encoded_features[mod].shape for mod in encoded_features]}")
+
         # Feature fusion and output generation
         fused_features = self.fusion_module(encoded_features, mask)
+        print(f"Fused features shape: {fused_features.shape}")
+
+        # Ensure input to final_block matches hidden_dim
+        # Ensure z ∈ ℝ^{B×H}, H: hidden dimension
         batch_size = fused_features.size(0)
-        fused_features = fused_features.repeat(1, self.sequence_length)
+
+        # Repeat the features to match the expected hidden dimension
+        if fused_features.size(1) != self.hidden_dim:
+            fused_features = fused_features.repeat(1, self.hidden_dim // fused_features.size(1))
+
+        # Fallback projection if dimension mismatch persists
+        # π_H: ℝ^k → ℝ^H
+        if fused_features.size(1) != self.hidden_dim:
+            projection = nn.Linear(fused_features.size(1), self.hidden_dim).to(fused_features.device)
+            fused_features = projection(fused_features)
+
+        # Final output generation
+        # ψ: ℝ^H → ℝ^M, M: output features
         output = self.final_block(fused_features)

         return output
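
An end-to-end usage sketch for this forward pass. The constructor keywords below are assumptions inferred from names visible in this diff (sequence_length, modality_channels, hidden_dim, num_heads, out_features); check the module for the actual signature.

import torch

# Hypothetical construction - argument names are assumptions.
encoder = DynamicFusionEncoder(
    sequence_length=12,
    modality_channels={'sat': 11, 'pv': 1},
    hidden_dim=256,
    num_heads=4,
    out_features=128,
)
inputs = {
    'sat': torch.randn(4, 12, 11),  # (B, L, C_sat)
    'pv':  torch.randn(4, 12, 1),   # (B, L, C_pv)
}
with torch.no_grad():
    y = encoder(inputs)             # encode -> gate -> attend -> fuse -> final_block
assert y.size(0) == 4               # batch dimension preserved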
@@ -246,6 +302,8 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

         # Enhanced projection with residual pathways
+        # With residual transformation
+        # φ_m: ℝ^H → ℝ^H
         self.feature_projections = nn.ModuleDict({
             modality: nn.Sequential(
                 nn.LayerNorm(self.hidden_dim),
@@ -266,28 +324,35 @@ def forward(

         """ Forward implementation with residual pathways """

+        # Encoded features dictionary
         encoded_features = {}

         # Feature extraction with residual connections
+        # x_m + R_m(x_m)
         for modality, x in inputs.items():
             if modality not in self.modality_encoders or x is None:
                 continue

-            encoded = self.modality_encoders[modality](x)
+            encoded = self.modality_encoders[modality](x)
+
+            # Residual connection
+            # x_m ⊕ R_m(x_m)
             projected = encoded + self.feature_projections[modality](encoded)
             encoded_features[modality] = projected

         if not encoded_features:
             raise ValueError("No valid features after encoding")

         # Gating with residual pathways
+        # g_m: x_m ⊕ g(x_m)
         if self.use_gating:
             gated_features = self.gating(encoded_features)
             for modality in encoded_features:
                 gated_features[modality] = gated_features[modality] + encoded_features[modality]
             encoded_features = gated_features

         # Attention with residual pathways
+        # A_m: x_m ⊕ A(x_m)
         if self.use_cross_attention and len(encoded_features) > 1:
             attended_features = self.cross_attention(encoded_features, mask)
             for modality in encoded_features:
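
The residual pattern used throughout this subclass, isolated as a self-contained sketch (the dimension is illustrative):

import torch
import torch.nn as nn

# x + R(x) with a LayerNorm -> Linear projection, mirroring feature_projections.
dim = 32
projection = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))
x = torch.randn(4, dim)
out = x + projection(x)   # residual pathway preserves the identity signal
assert out.shape == x.shape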

tests/models/multimodal/encoders/test_dynamic_encoder.py

Lines changed: 8 additions & 32 deletions
@@ -11,7 +11,7 @@
 from pvnet.models.multimodal.encoders.dynamic_encoder import DynamicFusionEncoder


-# Fixtures
+# Fixtures definition
 @pytest.fixture
 def minimal_config():
     """ Generate minimal config - basic functionality testing """
@@ -54,15 +54,15 @@ def minimal_config():
             'sat': feature_dim,
             'pv': feature_dim
         },
-        'hidden_dim': feature_dim,  # Changed to feature_dim
+        'hidden_dim': feature_dim,
         'dropout': 0.1
     },
     'dynamic_fusion': {
         'feature_dims': {
             'sat': feature_dim,
             'pv': feature_dim
         },
-        'hidden_dim': feature_dim,  # Changed to feature_dim
+        'hidden_dim': feature_dim,
         'num_heads': 4,
         'dropout': 0.1,
         'fusion_method': 'weighted_sum',
@@ -129,7 +129,7 @@ def test_single_modality(minimal_config, minimal_inputs):
     """ Test forward pass with single modality """
     encoder = create_encoder(minimal_config)

-    # Test with only satellite data
+    # Test with only satellite data - update later when included in model
     with torch.no_grad():
         sat_only = {'sat': minimal_inputs['sat']}
         output_sat = encoder(sat_only)
@@ -157,7 +157,7 @@ def test_intermediate_shapes(minimal_config, minimal_inputs):
     def hook_fn(module, input, output):
         if isinstance(output, dict):
             for key, value in output.items():
-                assert len(value.shape) == 3  # [batch, sequence, features]
+                assert len(value.shape) == 3
                 assert value.size(0) == batch_size
                 assert value.size(1) == sequence_length
                 assert value.size(2) == feature_dim
@@ -166,7 +166,6 @@ def hook_fn(module, input, output):
         assert output.size(0) == batch_size
         assert output.size(1) == sequence_length

-    # Register hooks
     if hasattr(encoder, 'gating'):
         encoder.gating.register_forward_hook(hook_fn)
     if hasattr(encoder, 'cross_attention'):
@@ -176,7 +175,7 @@ def hook_fn(module, input, output):
     encoder(minimal_inputs)


-# Robustness tests
+# Robustness testing
 @pytest.mark.parametrize("batch_size", [1, 4])
 def test_batch_sizes(minimal_config, minimal_inputs, batch_size):
     """ Test encoder behavior with different batch sizes """
@@ -250,7 +249,7 @@ def test_architecture_components(minimal_config):

     encoder = create_encoder(minimal_config)

-    # Test encoder layers
+    # Assert encoder layers
     assert hasattr(encoder, 'modality_encoders')
     assert hasattr(encoder, 'feature_projections')
     assert hasattr(encoder, 'fusion_module')
@@ -272,7 +271,6 @@ def hook(module, input, output):
             {k: v.shape for k, v in output.items()}
         return hook

-    # Register shape tracking hooks
     encoder.modality_encoders['sat'].register_forward_hook(hook_fn('sat_encoder'))
     encoder.feature_projections['sat'].register_forward_hook(hook_fn('sat_projection'))
     encoder.fusion_module.register_forward_hook(hook_fn('fusion'))
@@ -325,30 +323,8 @@ def attention_hook(module, input, output):
     encoder(minimal_inputs)

     if attention_outputs:
+
         # Verify attention weight distribution
         for modality, features in attention_outputs.items():
             std = features.std()
             assert std > 1e-6, "Attention weights too uniform"
-
-
-@pytest.mark.parametrize("noise_level", [0.1, 0.5, 1.0])
-def test_input_noise_robustness(minimal_config, minimal_inputs, noise_level):
-    """ Test encoder stability under different noise levels """
-
-    encoder = create_encoder(minimal_config)
-
-    # Add noise to inputs
-    noisy_inputs = {
-        k: v + noise_level * torch.randn_like(v)
-        for k, v in minimal_inputs.items()
-    }
-
-    with torch.no_grad():
-        clean_output = encoder(minimal_inputs)
-        noisy_output = encoder(noisy_inputs)
-
-    # Check output stability
-    relative_diff = (clean_output - noisy_output).abs().mean() / clean_output.abs().mean()
-    assert not torch.isnan(relative_diff)
-    assert not torch.isinf(relative_diff)
-    assert relative_diff < noise_level * 10
