Skip to content

Commit 09b5ff9

Browse files
committed
Fix
1 parent 0d09830 commit 09b5ff9

File tree

2 files changed

+309
-0
lines changed

2 files changed

+309
-0
lines changed
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# dynamic_encoder.py
2+
3+
""" Dynamic fusion encoder implementation for multimodal learning """
4+
5+
6+
from typing import Dict, Optional, List, Union
7+
import torch
8+
from torch import nn
9+
10+
from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
11+
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating
12+
from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention
13+
from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2
14+
15+
16+
class PVEncoder(nn.Module):
    """Simplified PV encoder - maintains sequence dimension.

    Maps per-site PV readings at every timestep to a fixed-width feature
    vector, producing [batch_size, sequence_length, out_features].
    """

    def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        """Build the per-timestep PV encoder.

        Args:
            sequence_length: Number of timesteps in the input sequence.
            num_sites: Number of PV sites (features) per timestep.
            out_features: Width of the encoded feature vector per timestep.
        """
        super().__init__()
        self.sequence_length = sequence_length
        self.num_sites = num_sites
        self.out_features = out_features

        # Linear, LayerNorm, ReLU and Dropout all operate on the trailing
        # dimension, so one stack handles every timestep simultaneously.
        self.encoder = nn.Sequential(
            nn.Linear(num_sites, out_features),
            nn.LayerNorm(out_features),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        """Encode PV data.

        Args:
            x: Tensor of shape [batch_size, sequence_length, num_sites].

        Returns:
            Tensor of shape [batch_size, sequence_length, out_features].
        """
        # Every layer in self.encoder acts on the last dimension only, so
        # applying it to the full 3-D tensor is mathematically identical to
        # the original per-timestep loop + stack, without the Python loop.
        return self.encoder(x)
42+
43+
44+
class DynamicFusionEncoder(AbstractNWPSatelliteEncoder):

    """Encoder that implements dynamic fusion of satellite/NWP data streams"""

    def __init__(
        self,
        sequence_length: int,
        image_size_pixels: int,
        modality_channels: Dict[str, int],
        out_features: int,
        modality_encoders: Dict[str, dict],
        cross_attention: Dict,
        modality_gating: Dict,
        dynamic_fusion: Dict,
        hidden_dim: int = 256,
        fc_features: int = 128,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_gating: bool = True,
        use_cross_attention: bool = True
    ):
        """Dynamic fusion encoder for multimodal satellite/NWP data.

        Args:
            sequence_length: Number of timesteps per sample.
            image_size_pixels: Default side length of image-like inputs.
            modality_channels: Mapping of modality name -> input channel count.
            out_features: Width of the final encoded output vector.
            modality_encoders: Per-modality encoder config dicts, keyed like
                ``modality_channels``.
            cross_attention: Kwargs for CrossModalAttention (embed_dim is
                overwritten with ``hidden_dim`` below).
            modality_gating: Kwargs for ModalityGating (feature_dims is
                overwritten below).
            dynamic_fusion: Kwargs for DynamicFusionModule (feature_dims and
                hidden_dim are overwritten below).
            hidden_dim: Internal feature width shared by all fusion modules.
            fc_features: Hidden width of the final projection MLP.
            num_heads: NOTE(review): not referenced anywhere in this
                constructor - attention head counts come from the config
                dicts instead; confirm whether this parameter is vestigial.
            dropout: Dropout probability for the feature projections.
            use_gating: Whether to build/apply the ModalityGating module.
            use_cross_attention: Whether to build/apply CrossModalAttention.
        """
        super().__init__(
            sequence_length=sequence_length,
            image_size_pixels=image_size_pixels,
            in_channels=sum(modality_channels.values()),
            out_features=out_features
        )

        self.modalities = list(modality_channels.keys())
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length

        # Initialize modality-specific encoders.
        # NOTE(review): if a modality name matches neither the 'nwp'/'sat'
        # branch nor 'pv', no encoder is registered for it (and in that case
        # the local `encoder` below would be unbound) - confirm callers only
        # pass these modality names.
        self.modality_encoders = nn.ModuleDict()
        for modality, config in modality_encoders.items():
            # Copy so the per-modality config dict passed in is not mutated.
            config = config.copy()
            if 'nwp' in modality or 'sat' in modality:
                # 3D-conv image encoder producing a flat feature vector.
                encoder = DefaultPVNet2(
                    sequence_length=sequence_length,
                    image_size_pixels=config.get('image_size_pixels', image_size_pixels),
                    in_channels=modality_channels[modality],
                    out_features=config.get('out_features', hidden_dim),
                    number_of_conv3d_layers=config.get('number_of_conv3d_layers', 4),
                    conv3d_channels=config.get('conv3d_channels', 32),
                    batch_norm=config.get('batch_norm', True),
                    fc_dropout=config.get('fc_dropout', 0.2)
                )

                # Reshape the flat encoder output back into a sequence:
                # [batch, out_features] -> [batch, sequence_length, hidden_dim//sequence_length].
                # NOTE(review): this yields hidden_dim//sequence_length features
                # per step, while the projections/gating/fusion below are sized
                # for hidden_dim (and PVEncoder emits hidden_dim) - confirm the
                # intended per-step width; it also requires the encoder's
                # out_features to equal sequence_length * (hidden_dim//sequence_length).
                self.modality_encoders[modality] = nn.Sequential(
                    encoder,
                    nn.Unflatten(1, (sequence_length, hidden_dim//sequence_length))
                )

            elif modality == 'pv':
                # PV encoder keeps the sequence dimension natively; note the
                # config's 'out_features' entry is ignored here in favour of
                # hidden_dim.
                self.modality_encoders[modality] = PVEncoder(
                    sequence_length=sequence_length,
                    num_sites=config['num_sites'],
                    out_features=hidden_dim
                )

        # Feature projections, one per modality.
        # NOTE(review): feature_projections is built but never applied in
        # forward() as written - confirm whether it should be called on
        # encoded_features before gating/fusion.
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
            for modality in modality_channels.keys()
        })

        # Optional modality gating: learns per-modality scaling of features.
        self.use_gating = use_gating
        if use_gating:
            gating_config = modality_gating.copy()
            # Force every modality's gate to operate at hidden_dim width.
            gating_config['feature_dims'] = {
                mod: hidden_dim for mod in modality_channels.keys()
            }
            self.gating = ModalityGating(**gating_config)

        # Optional cross-modal attention between modality feature streams.
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            attention_config = cross_attention.copy()
            # embed_dim must match the per-step feature width used above.
            attention_config['embed_dim'] = hidden_dim
            self.cross_attention = CrossModalAttention(**attention_config)

        # Dynamic fusion module: combines the modality streams into one.
        fusion_config = dynamic_fusion.copy()
        fusion_config['feature_dims'] = {
            mod: hidden_dim for mod in modality_channels.keys()
        }
        fusion_config['hidden_dim'] = hidden_dim
        self.fusion_module = DynamicFusionModule(**fusion_config)

        # Final output projection: flatten the fused sequence and map to
        # out_features via a small ELU MLP.
        self.final_block = nn.Sequential(
            nn.Linear(hidden_dim * sequence_length, fc_features),
            nn.ELU(),
            nn.Linear(fc_features, out_features),
            nn.ELU(),
        )

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass of the dynamic fusion encoder.

        Args:
            inputs: Mapping of modality name -> input tensor. Modalities
                without a registered encoder are silently skipped.
            mask: Optional attention mask forwarded to cross-attention and
                fusion. NOTE(review): expected shape/semantics are defined by
                CrossModalAttention/DynamicFusionModule - confirm there.

        Returns:
            Tensor of shape [batch_size, out_features].

        Raises:
            ValueError: If no input modality has a registered encoder.
        """
        # Initial encoding of each modality
        encoded_features = {}
        for modality, x in inputs.items():
            if modality not in self.modality_encoders:
                continue

            # Apply modality-specific encoder
            # Output shape: [batch_size, sequence_length, hidden_dim]
            encoded_features[modality] = self.modality_encoders[modality](x)

        if not encoded_features:
            raise ValueError("No valid features found in inputs")

        # Apply modality gating if enabled
        if self.use_gating:
            encoded_features = self.gating(encoded_features)

        # Apply cross-modal attention if enabled and more than one modality
        if self.use_cross_attention and len(encoded_features) > 1:
            encoded_features = self.cross_attention(encoded_features, mask)

        # Apply dynamic fusion
        fused_features = self.fusion_module(encoded_features, mask)  # [batch, sequence, hidden]

        # Reshape and apply final projection
        batch_size = fused_features.size(0)
        fused_features = fused_features.reshape(batch_size, -1)  # Flatten sequence dimension
        output = self.final_block(fused_features)

        return output
184+
185+
186+
class DynamicResidualEncoder(DynamicFusionEncoder):
    """Dynamic fusion encoder with residual connections.

    Identical to DynamicFusionEncoder except that the per-modality feature
    projections are replaced with a deeper two-layer stack suitable for
    residual addition.
    """

    def __init__(self, *args, **kwargs):
        """Accept exactly the DynamicFusionEncoder constructor arguments."""
        super().__init__(*args, **kwargs)

        # The original indexed kwargs['modality_channels'] directly, which
        # raises KeyError whenever that argument is passed positionally (the
        # parent signature allows it). The parent already records the
        # modality names and hidden width, so reuse those attributes.
        dropout = kwargs.get('dropout', 0.1)  # mirrors the parent's default

        # Override feature projections to include residual connections.
        # NOTE(review): feature_projections is not referenced in the parent
        # forward() as written - confirm the fusion pipeline applies it.
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim)
            )
            for modality in self.modalities
        })
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import pytest
2+
import torch
3+
from typing import Dict
4+
5+
from pvnet.models.multimodal.encoders.dynamic_encoder import DynamicFusionEncoder
6+
7+
@pytest.fixture
def minimal_config():
    """Smallest configuration that exercises the full encoder stack."""
    seq_len = 12
    hidden = 60  # divisible by seq_len so hidden // seq_len is exact (5)

    # Per-timestep feature width; the sat and pv encoder outputs must agree
    # on this value for fusion to line up.
    per_step = hidden // seq_len

    sat_encoder = {
        'image_size_pixels': 24,
        'out_features': per_step * seq_len,  # 60
        'number_of_conv3d_layers': 2,
        'conv3d_channels': 16,
        'batch_norm': True,
        'fc_dropout': 0.1,
    }
    pv_encoder = {
        'num_sites': 10,
        'out_features': per_step,  # 5 - keeps per-step widths matched
    }
    shared_dims = {'sat': hidden, 'pv': hidden}

    return {
        'sequence_length': seq_len,
        'image_size_pixels': 24,
        'modality_channels': {'sat': 2, 'pv': 10},
        'out_features': 32,
        'hidden_dim': hidden,
        'fc_features': 32,
        'modality_encoders': {'sat': sat_encoder, 'pv': pv_encoder},
        'cross_attention': {
            'embed_dim': hidden,
            'num_heads': 4,
            'dropout': 0.1,
            'num_modalities': 2,
        },
        'modality_gating': {
            'feature_dims': shared_dims,
            'hidden_dim': hidden,
            'dropout': 0.1,
        },
        'dynamic_fusion': {
            'feature_dims': dict(shared_dims),
            'hidden_dim': hidden,
            'num_heads': 4,
            'dropout': 0.1,
            'fusion_method': 'weighted_sum',
            'use_residual': True,
        },
    }
66+
67+
@pytest.fixture
def minimal_inputs(minimal_config):
    """Random tensors shaped to match minimal_config's modalities."""
    n_batch = 2
    seq_len = minimal_config['sequence_length']

    # sat: [batch, channels, seq, H, W]; pv: [batch, seq, sites]
    inputs = {
        'sat': torch.randn(n_batch, 2, seq_len, 24, 24),
        'pv': torch.randn(n_batch, seq_len, 10),
    }
    return inputs
77+
78+
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_batch_sizes(minimal_config, minimal_inputs, batch_size):
    """Encoder output keeps shape (batch_size, out_features) for varied batches.

    Fixes two collection-time defects in the original: the stray ``self``
    parameter (pytest collects module-level functions, so ``self`` was
    resolved as a nonexistent fixture) and the missing ``parametrize``
    marker for ``batch_size``.
    """
    encoder = DynamicFusionEncoder(
        sequence_length=minimal_config['sequence_length'],
        image_size_pixels=minimal_config['image_size_pixels'],
        modality_channels=minimal_config['modality_channels'],
        out_features=minimal_config['out_features'],
        modality_encoders=minimal_config['modality_encoders'],
        cross_attention=minimal_config['cross_attention'],
        modality_gating=minimal_config['modality_gating'],
        dynamic_fusion=minimal_config['dynamic_fusion'],
        hidden_dim=minimal_config['hidden_dim'],
        fc_features=minimal_config['fc_features']
    )

    # Shrink by slicing or grow by whole-tensor repetition. Parametrized
    # sizes are 1 or multiples of the fixture batch (2), so the integer
    # division below never truncates.
    adjusted_inputs = {}
    for k, v in minimal_inputs.items():
        if batch_size < v.size(0):
            adjusted_inputs[k] = v[:batch_size]
        else:
            repeat_factor = batch_size // v.size(0)
            adjusted_inputs[k] = v.repeat(repeat_factor, *[1] * (len(v.shape) - 1))

    with torch.no_grad():
        output = encoder(adjusted_inputs)

    assert output.shape == (batch_size, minimal_config['out_features'])
    assert not torch.isnan(output).any()

0 commit comments

Comments
 (0)