Skip to content

Commit 38768f4

Browse files
committed
Dynamic encoder finalised
1 parent cc42d38 commit 38768f4

File tree

1 file changed

+155
-56
lines changed

1 file changed

+155
-56
lines changed
Lines changed: 155 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# dynamic_encoder.py
22

3-
""" Dynamic fusion encoder implementation for multimodal learning """
3+
"""
4+
Dynamic fusion encoder implementation for multimodal learning
45
6+
Defines PVEncoder, DynamicFusionEncoder and DynamicResidualEncoder
7+
"""
58

69
from typing import Dict, Optional, List, Union
710
import torch
@@ -13,16 +16,27 @@
1316
from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2
1417

1518

19+
# Attention head compatibility function
def get_compatible_heads(dim: int, target_heads: int) -> int:
    """Return the largest head count <= target_heads that evenly divides dim.

    Scans downward from min(target_heads, dim) and returns the first
    divisor found. Falls back to 1 when the scan is empty (dim == 0);
    for any dim >= 1 the scan always terminates at h == 1 anyway.
    """
    upper = min(target_heads, dim)
    divisors = (h for h in range(upper, 0, -1) if dim % h == 0)
    return next(divisors, 1)
27+
28+
29+
# Processes PV data maintaining temporal sequence
1630
class PVEncoder(nn.Module):
17-
""" Simplified PV encoder - maintains sequence dimension """
31+
""" PV specific encoder implementation with sequence preservation """
1832

1933
def __init__(self, sequence_length: int, num_sites: int, out_features: int):
2034
super().__init__()
2135
self.sequence_length = sequence_length
2236
self.num_sites = num_sites
2337
self.out_features = out_features
2438

25-
# Process each timestep independently
39+
# Basic feature extraction network
2640
self.encoder = nn.Sequential(
2741
nn.Linear(num_sites, out_features),
2842
nn.LayerNorm(out_features),
@@ -31,20 +45,18 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):
3145
)
3246

3347
def forward(self, x):
    """Encode each timestep independently, preserving temporal order.

    Args:
        x: tensor of shape [batch_size, sequence_length, num_sites].

    Returns:
        Tensor of shape [batch_size, sequence_length, out_features].
    """
    # BUGFIX: the previous version ended the append line with a stray
    # line-continuation backslash, which joined it to the following
    # comment line and made the method a syntax error. The unused
    # `batch_size` local is also removed.
    outputs = []
    for t in range(self.sequence_length):
        # Apply the shared per-timestep encoder to each slice.
        outputs.append(self.encoder(x[:, t]))
    # Re-stack along the sequence dimension.
    return torch.stack(outputs, dim=1)
4256

4357

58+
# Primary fusion encoder implementation
4459
class DynamicFusionEncoder(AbstractNWPSatelliteEncoder):
45-
46-
"""Encoder that implements dynamic fusion of satellite/NWP data streams"""
47-
4860
def __init__(
4961
self,
5062
sequence_length: int,
@@ -62,85 +74,124 @@ def __init__(
6274
use_gating: bool = True,
6375
use_cross_attention: bool = True
6476
):
65-
"""Dynamic fusion encoder for multimodal satellite/NWP data."""
77+
""" Dynamic fusion encoder initialisation """
78+
6679
super().__init__(
6780
sequence_length=sequence_length,
6881
image_size_pixels=image_size_pixels,
6982
in_channels=sum(modality_channels.values()),
7083
out_features=out_features
7184
)
85+
86+
# Dimension validation and compatibility
87+
if hidden_dim % sequence_length != 0:
88+
feature_dim = ((hidden_dim + sequence_length - 1) // sequence_length)
89+
hidden_dim = feature_dim * sequence_length
90+
else:
91+
feature_dim = hidden_dim // sequence_length
92+
93+
# Attention mechanism setup
94+
attention_heads = cross_attention.get('num_heads', num_heads)
95+
attention_heads = get_compatible_heads(feature_dim, attention_heads)
7296

73-
self.modalities = list(modality_channels.keys())
97+
# Feature dimension adjustment for attention
98+
if feature_dim < attention_heads:
99+
feature_dim = attention_heads
100+
hidden_dim = feature_dim * sequence_length
101+
elif feature_dim % attention_heads != 0:
102+
feature_dim = ((feature_dim + attention_heads - 1) // attention_heads) * attention_heads
103+
hidden_dim = feature_dim * sequence_length
104+
105+
# Architecture dimensions
106+
self.feature_dim = feature_dim
74107
self.hidden_dim = hidden_dim
75108
self.sequence_length = sequence_length
109+
self.modalities = list(modality_channels.keys())
76110

77-
# Initialize modality-specific encoders
111+
# Update configs with validated dimensions
112+
cross_attention['num_heads'] = attention_heads
113+
dynamic_fusion['num_heads'] = attention_heads
114+
115+
# Modality specific encoder instantiation
78116
self.modality_encoders = nn.ModuleDict()
79117
for modality, config in modality_encoders.items():
80118
config = config.copy()
81119
if 'nwp' in modality or 'sat' in modality:
120+
121+
# Image based modality encoder
82122
encoder = DefaultPVNet2(
83123
sequence_length=sequence_length,
84124
image_size_pixels=config.get('image_size_pixels', image_size_pixels),
85125
in_channels=modality_channels[modality],
86-
out_features=config.get('out_features', hidden_dim),
126+
out_features=hidden_dim,
87127
number_of_conv3d_layers=config.get('number_of_conv3d_layers', 4),
88128
conv3d_channels=config.get('conv3d_channels', 32),
89129
batch_norm=config.get('batch_norm', True),
90-
fc_dropout=config.get('fc_dropout', 0.2)
130+
fc_dropout=dropout
91131
)
92-
132+
93133
self.modality_encoders[modality] = nn.Sequential(
94134
encoder,
95-
nn.Unflatten(1, (sequence_length, hidden_dim//sequence_length))
135+
nn.Linear(hidden_dim, sequence_length * feature_dim),
136+
nn.Unflatten(-1, (sequence_length, feature_dim))
96137
)
97-
98138
elif modality == 'pv':
139+
140+
# PV specific encoder
99141
self.modality_encoders[modality] = PVEncoder(
100142
sequence_length=sequence_length,
101143
num_sites=config['num_sites'],
102-
out_features=hidden_dim
144+
out_features=feature_dim
103145
)
104146

105-
# Feature projections
147+
# Feature transformation layers
106148
self.feature_projections = nn.ModuleDict({
107149
modality: nn.Sequential(
108-
nn.Linear(hidden_dim, hidden_dim),
109-
nn.LayerNorm(hidden_dim),
150+
nn.LayerNorm(feature_dim),
151+
nn.Linear(feature_dim, feature_dim),
110152
nn.ReLU(),
111153
nn.Dropout(dropout)
112154
)
113155
for modality in modality_channels.keys()
114156
})
115157

116-
# Optional modality gating
158+
# Modality gating mechanism
117159
self.use_gating = use_gating
118160
if use_gating:
119161
gating_config = modality_gating.copy()
120-
gating_config['feature_dims'] = {
121-
mod: hidden_dim for mod in modality_channels.keys()
122-
}
162+
gating_config.update({
163+
'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
164+
'hidden_dim': feature_dim
165+
})
123166
self.gating = ModalityGating(**gating_config)
124167

125-
# Optional cross-modal attention
168+
# Cross modal attention mechanism
126169
self.use_cross_attention = use_cross_attention
127170
if use_cross_attention:
128171
attention_config = cross_attention.copy()
129-
attention_config['embed_dim'] = hidden_dim
172+
attention_config.update({
173+
'embed_dim': feature_dim,
174+
'num_heads': attention_heads,
175+
'dropout': dropout
176+
})
130177
self.cross_attention = CrossModalAttention(**attention_config)
131178

132-
# Dynamic fusion module
179+
# Dynamic fusion implementation
133180
fusion_config = dynamic_fusion.copy()
134-
fusion_config['feature_dims'] = {
135-
mod: hidden_dim for mod in modality_channels.keys()
136-
}
137-
fusion_config['hidden_dim'] = hidden_dim
181+
fusion_config.update({
182+
'feature_dims': {mod: feature_dim for mod in modality_channels.keys()},
183+
'hidden_dim': feature_dim,
184+
'num_heads': attention_heads,
185+
'dropout': dropout
186+
})
138187
self.fusion_module = DynamicFusionModule(**fusion_config)
139188

140-
# Final output projection
189+
# Output network definition
141190
self.final_block = nn.Sequential(
142-
nn.Linear(hidden_dim * sequence_length, fc_features),
191+
nn.Linear(hidden_dim, fc_features),
192+
nn.LayerNorm(fc_features),
143193
nn.ELU(),
194+
nn.Dropout(dropout),
144195
nn.Linear(fc_features, out_features),
145196
nn.ELU(),
146197
)
@@ -150,54 +201,102 @@ def forward(
150201
inputs: Dict[str, torch.Tensor],
151202
mask: Optional[torch.Tensor] = None
152203
) -> torch.Tensor:
153-
"""Forward pass of the dynamic fusion encoder"""
154-
# Initial encoding of each modality
204+
205+
""" Dynamic fusion forward pass implementation """
206+
155207
encoded_features = {}
208+
209+
# Modality specific encoding
156210
for modality, x in inputs.items():
157-
if modality not in self.modality_encoders:
211+
if modality not in self.modality_encoders or x is None:
158212
continue
213+
214+
# Feature extraction and projection
215+
encoded = self.modality_encoders[modality](x)
216+
projected = torch.stack([
217+
self.feature_projections[modality](encoded[:, t])
218+
for t in range(self.sequence_length)
219+
], dim=1)
159220

160-
# Apply modality-specific encoder
161-
# Output shape: [batch_size, sequence_length, hidden_dim]
162-
encoded_features[modality] = self.modality_encoders[modality](x)
221+
encoded_features[modality] = projected
163222

164223
if not encoded_features:
165-
raise ValueError("No valid features found in inputs")
224+
raise ValueError("No valid features after encoding")
166225

167-
# Apply modality gating if enabled
226+
# Apply modality interaction mechanisms
168227
if self.use_gating:
169228
encoded_features = self.gating(encoded_features)
170229

171-
# Apply cross-modal attention if enabled and more than one modality
172230
if self.use_cross_attention and len(encoded_features) > 1:
173231
encoded_features = self.cross_attention(encoded_features, mask)
174232

175-
# Apply dynamic fusion
176-
fused_features = self.fusion_module(encoded_features, mask) # [batch, sequence, hidden]
177-
178-
# Reshape and apply final projection
233+
# Feature fusion and output generation
234+
fused_features = self.fusion_module(encoded_features, mask)
179235
batch_size = fused_features.size(0)
180-
fused_features = fused_features.reshape(batch_size, -1) # Flatten sequence dimension
236+
fused_features = fused_features.repeat(1, self.sequence_length)
181237
output = self.final_block(fused_features)
182238

183239
return output
184240

185241

186242
class DynamicResidualEncoder(DynamicFusionEncoder):
    """Dynamic fusion implementation with residual connectivity.

    Identical to DynamicFusionEncoder, except that a residual (skip)
    connection is added around the per-modality projection, the gating
    stage, and the cross-attention stage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Replace the parent's per-modality projections with a wider
        # residual bottleneck (dim -> 2*dim -> dim).
        #
        # BUGFIX: the modality encoders emit per-timestep features of
        # size self.feature_dim (set by the parent __init__), not
        # self.hidden_dim (= feature_dim * sequence_length). Building
        # these layers on hidden_dim caused a shape mismatch in
        # forward() whenever sequence_length > 1.
        dim = self.feature_dim
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.LayerNorm(dim),
                nn.Linear(dim, dim * 2),
                nn.ReLU(),
                nn.Dropout(kwargs.get('dropout', 0.1)),
                nn.Linear(dim * 2, dim),
                nn.LayerNorm(dim),
            )
            for modality in kwargs['modality_channels'].keys()
        })

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass with residual pathways at each fusion stage.

        Args:
            inputs: mapping of modality name -> input tensor; modalities
                without a registered encoder (or with a None tensor) are
                skipped.
            mask: optional attention mask forwarded to the attention and
                fusion modules.

        Returns:
            Tensor of shape [batch_size, out_features].

        Raises:
            ValueError: if no input modality survives encoding.
        """
        encoded_features = {}

        # Encode each modality and add a residual projection on top.
        for modality, x in inputs.items():
            if modality not in self.modality_encoders or x is None:
                continue
            encoded = self.modality_encoders[modality](x)
            encoded_features[modality] = encoded + self.feature_projections[modality](encoded)

        if not encoded_features:
            raise ValueError("No valid features after encoding")

        # Gating stage with residual pathway.
        if self.use_gating:
            gated = self.gating(encoded_features)
            encoded_features = {
                mod: gated[mod] + encoded_features[mod]
                for mod in encoded_features
            }

        # Cross-modal attention with residual pathway (only meaningful
        # when more than one modality is present).
        if self.use_cross_attention and len(encoded_features) > 1:
            attended = self.cross_attention(encoded_features, mask)
            encoded_features = {
                mod: attended[mod] + encoded_features[mod]
                for mod in encoded_features
            }

        # Fuse modalities, tile across the sequence to match the width
        # expected by final_block, and project to the output space.
        # NOTE(review): this assumes fusion_module returns
        # [batch, feature_dim] — confirm against DynamicFusionModule.
        fused = self.fusion_module(encoded_features, mask)
        fused = fused.repeat(1, self.sequence_length)
        return self.final_block(fused)

0 commit comments

Comments
 (0)