1
1
# dynamic_encoder.py
2
2
3
- """ Dynamic fusion encoder implementation for multimodal learning """
3
+ """
4
+ Dynamic fusion encoder implementation for multimodal learning
4
5
6
+ Defines PVEncoder, DynamicFusionEncoder and DynamicResidualEncoder
7
+ """
5
8
6
9
from typing import Dict , Optional , List , Union
7
10
import torch
13
16
from pvnet .models .multimodal .encoders .encoders3d import DefaultPVNet2
14
17
15
18
19
# Attention head compatibility helper
def get_compatible_heads(dim: int, target_heads: int) -> int:
    """Return the largest head count <= target_heads that evenly divides dim.

    Used to pick a valid ``num_heads`` for attention layers, where the
    embedding dimension must be divisible by the number of heads.
    """
    candidate = min(target_heads, dim)
    while candidate > 0:
        if dim % candidate == 0:
            return candidate
        candidate -= 1
    # Degenerate inputs (dim == 0 or target_heads <= 0) fall back to one head
    return 1
27
+
28
+
29
+ # Processes PV data maintaining temporal sequence
16
30
class PVEncoder (nn .Module ):
17
- """ Simplified PV encoder - maintains sequence dimension """
31
+ """ PV specific encoder implementation with sequence preservation """
18
32
19
33
def __init__ (self , sequence_length : int , num_sites : int , out_features : int ):
20
34
super ().__init__ ()
21
35
self .sequence_length = sequence_length
22
36
self .num_sites = num_sites
23
37
self .out_features = out_features
24
38
25
- # Process each timestep independently
39
+ # Basic feature extraction network
26
40
self .encoder = nn .Sequential (
27
41
nn .Linear (num_sites , out_features ),
28
42
nn .LayerNorm (out_features ),
@@ -31,20 +45,18 @@ def __init__(self, sequence_length: int, num_sites: int, out_features: int):
31
45
)
32
46
33
47
def forward (self , x ):
34
- # x: [batch_size, sequence_length, num_sites]
48
+
49
+ # Sequential processing - maintain temporal order
35
50
batch_size = x .shape [0 ]
36
- # Process each timestep
37
51
out = []
38
52
for t in range (self .sequence_length ):
39
- out .append (self .encoder (x [:, t ]))
40
- # Stack along sequence dimension
41
- return torch .stack (out , dim = 1 ) # [batch_size, sequence_length, out_features]
53
+ out .append (self .encoder (x [:, t ]))\
54
+ # Reshape maintaining sequence dimension
55
+ return torch .stack (out , dim = 1 )
42
56
43
57
58
+ # Primary fusion encoder implementation
44
59
class DynamicFusionEncoder (AbstractNWPSatelliteEncoder ):
45
-
46
- """Encoder that implements dynamic fusion of satellite/NWP data streams"""
47
-
48
60
def __init__ (
49
61
self ,
50
62
sequence_length : int ,
@@ -62,85 +74,124 @@ def __init__(
62
74
use_gating : bool = True ,
63
75
use_cross_attention : bool = True
64
76
):
65
- """Dynamic fusion encoder for multimodal satellite/NWP data."""
77
+ """ Dynamic fusion encoder initialisation """
78
+
66
79
super ().__init__ (
67
80
sequence_length = sequence_length ,
68
81
image_size_pixels = image_size_pixels ,
69
82
in_channels = sum (modality_channels .values ()),
70
83
out_features = out_features
71
84
)
85
+
86
+ # Dimension validation and compatibility
87
+ if hidden_dim % sequence_length != 0 :
88
+ feature_dim = ((hidden_dim + sequence_length - 1 ) // sequence_length )
89
+ hidden_dim = feature_dim * sequence_length
90
+ else :
91
+ feature_dim = hidden_dim // sequence_length
92
+
93
+ # Attention mechanism setup
94
+ attention_heads = cross_attention .get ('num_heads' , num_heads )
95
+ attention_heads = get_compatible_heads (feature_dim , attention_heads )
72
96
73
- self .modalities = list (modality_channels .keys ())
97
+ # Feature dimension adjustment for attention
98
+ if feature_dim < attention_heads :
99
+ feature_dim = attention_heads
100
+ hidden_dim = feature_dim * sequence_length
101
+ elif feature_dim % attention_heads != 0 :
102
+ feature_dim = ((feature_dim + attention_heads - 1 ) // attention_heads ) * attention_heads
103
+ hidden_dim = feature_dim * sequence_length
104
+
105
+ # Architecture dimensions
106
+ self .feature_dim = feature_dim
74
107
self .hidden_dim = hidden_dim
75
108
self .sequence_length = sequence_length
109
+ self .modalities = list (modality_channels .keys ())
76
110
77
- # Initialize modality-specific encoders
111
+ # Update configs with validated dimensions
112
+ cross_attention ['num_heads' ] = attention_heads
113
+ dynamic_fusion ['num_heads' ] = attention_heads
114
+
115
+ # Modality specific encoder instantiation
78
116
self .modality_encoders = nn .ModuleDict ()
79
117
for modality , config in modality_encoders .items ():
80
118
config = config .copy ()
81
119
if 'nwp' in modality or 'sat' in modality :
120
+
121
+ # Image based modality encoder
82
122
encoder = DefaultPVNet2 (
83
123
sequence_length = sequence_length ,
84
124
image_size_pixels = config .get ('image_size_pixels' , image_size_pixels ),
85
125
in_channels = modality_channels [modality ],
86
- out_features = config . get ( 'out_features' , hidden_dim ) ,
126
+ out_features = hidden_dim ,
87
127
number_of_conv3d_layers = config .get ('number_of_conv3d_layers' , 4 ),
88
128
conv3d_channels = config .get ('conv3d_channels' , 32 ),
89
129
batch_norm = config .get ('batch_norm' , True ),
90
- fc_dropout = config . get ( 'fc_dropout' , 0.2 )
130
+ fc_dropout = dropout
91
131
)
92
-
132
+
93
133
self .modality_encoders [modality ] = nn .Sequential (
94
134
encoder ,
95
- nn .Unflatten (1 , (sequence_length , hidden_dim // sequence_length ))
135
+ nn .Linear (hidden_dim , sequence_length * feature_dim ),
136
+ nn .Unflatten (- 1 , (sequence_length , feature_dim ))
96
137
)
97
-
98
138
elif modality == 'pv' :
139
+
140
+ # PV specific encoder
99
141
self .modality_encoders [modality ] = PVEncoder (
100
142
sequence_length = sequence_length ,
101
143
num_sites = config ['num_sites' ],
102
- out_features = hidden_dim
144
+ out_features = feature_dim
103
145
)
104
146
105
- # Feature projections
147
+ # Feature transformation layers
106
148
self .feature_projections = nn .ModuleDict ({
107
149
modality : nn .Sequential (
108
- nn .Linear ( hidden_dim , hidden_dim ),
109
- nn .LayerNorm ( hidden_dim ),
150
+ nn .LayerNorm ( feature_dim ),
151
+ nn .Linear ( feature_dim , feature_dim ),
110
152
nn .ReLU (),
111
153
nn .Dropout (dropout )
112
154
)
113
155
for modality in modality_channels .keys ()
114
156
})
115
157
116
- # Optional modality gating
158
+ # Modality gating mechanism
117
159
self .use_gating = use_gating
118
160
if use_gating :
119
161
gating_config = modality_gating .copy ()
120
- gating_config ['feature_dims' ] = {
121
- mod : hidden_dim for mod in modality_channels .keys ()
122
- }
162
+ gating_config .update ({
163
+ 'feature_dims' : {mod : feature_dim for mod in modality_channels .keys ()},
164
+ 'hidden_dim' : feature_dim
165
+ })
123
166
self .gating = ModalityGating (** gating_config )
124
167
125
- # Optional cross- modal attention
168
+ # Cross modal attention mechanism
126
169
self .use_cross_attention = use_cross_attention
127
170
if use_cross_attention :
128
171
attention_config = cross_attention .copy ()
129
- attention_config ['embed_dim' ] = hidden_dim
172
+ attention_config .update ({
173
+ 'embed_dim' : feature_dim ,
174
+ 'num_heads' : attention_heads ,
175
+ 'dropout' : dropout
176
+ })
130
177
self .cross_attention = CrossModalAttention (** attention_config )
131
178
132
- # Dynamic fusion module
179
+ # Dynamic fusion implementation
133
180
fusion_config = dynamic_fusion .copy ()
134
- fusion_config ['feature_dims' ] = {
135
- mod : hidden_dim for mod in modality_channels .keys ()
136
- }
137
- fusion_config ['hidden_dim' ] = hidden_dim
181
+ fusion_config .update ({
182
+ 'feature_dims' : {mod : feature_dim for mod in modality_channels .keys ()},
183
+ 'hidden_dim' : feature_dim ,
184
+ 'num_heads' : attention_heads ,
185
+ 'dropout' : dropout
186
+ })
138
187
self .fusion_module = DynamicFusionModule (** fusion_config )
139
188
140
- # Final output projection
189
+ # Output network definition
141
190
self .final_block = nn .Sequential (
142
- nn .Linear (hidden_dim * sequence_length , fc_features ),
191
+ nn .Linear (hidden_dim , fc_features ),
192
+ nn .LayerNorm (fc_features ),
143
193
nn .ELU (),
194
+ nn .Dropout (dropout ),
144
195
nn .Linear (fc_features , out_features ),
145
196
nn .ELU (),
146
197
)
@@ -150,54 +201,102 @@ def forward(
150
201
inputs : Dict [str , torch .Tensor ],
151
202
mask : Optional [torch .Tensor ] = None
152
203
) -> torch .Tensor :
153
- """Forward pass of the dynamic fusion encoder"""
154
- # Initial encoding of each modality
204
+
205
+ """ Dynamic fusion forward pass implementation """
206
+
155
207
encoded_features = {}
208
+
209
+ # Modality specific encoding
156
210
for modality , x in inputs .items ():
157
- if modality not in self .modality_encoders :
211
+ if modality not in self .modality_encoders or x is None :
158
212
continue
213
+
214
+ # Feature extraction and projection
215
+ encoded = self .modality_encoders [modality ](x )
216
+ projected = torch .stack ([
217
+ self .feature_projections [modality ](encoded [:, t ])
218
+ for t in range (self .sequence_length )
219
+ ], dim = 1 )
159
220
160
- # Apply modality-specific encoder
161
- # Output shape: [batch_size, sequence_length, hidden_dim]
162
- encoded_features [modality ] = self .modality_encoders [modality ](x )
221
+ encoded_features [modality ] = projected
163
222
164
223
if not encoded_features :
165
- raise ValueError ("No valid features found in inputs " )
224
+ raise ValueError ("No valid features after encoding " )
166
225
167
- # Apply modality gating if enabled
226
+ # Apply modality interaction mechanisms
168
227
if self .use_gating :
169
228
encoded_features = self .gating (encoded_features )
170
229
171
- # Apply cross-modal attention if enabled and more than one modality
172
230
if self .use_cross_attention and len (encoded_features ) > 1 :
173
231
encoded_features = self .cross_attention (encoded_features , mask )
174
232
175
- # Apply dynamic fusion
176
- fused_features = self .fusion_module (encoded_features , mask ) # [batch, sequence, hidden]
177
-
178
- # Reshape and apply final projection
233
+ # Feature fusion and output generation
234
+ fused_features = self .fusion_module (encoded_features , mask )
179
235
batch_size = fused_features .size (0 )
180
- fused_features = fused_features .reshape ( batch_size , - 1 ) # Flatten sequence dimension
236
+ fused_features = fused_features .repeat ( 1 , self . sequence_length )
181
237
output = self .final_block (fused_features )
182
238
183
239
return output
184
240
185
241
186
242
class DynamicResidualEncoder(DynamicFusionEncoder):
    """ Dynamic fusion implementation with residual connectivity """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Enhanced projection with residual pathways: a bottleneck-expansion
        # MLP (dim -> 2*dim -> dim) replacing the parent's projections.
        # NOTE(review): this indexes kwargs['modality_channels'] directly, so
        # construction fails with KeyError if modality_channels is passed
        # positionally via *args — confirm call sites use keyword arguments.
        # NOTE(review): parent appears to build its projections with a
        # per-timestep feature dim, while this override uses self.hidden_dim;
        # verify the two dimensions agree for the intended configs.
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.LayerNorm(self.hidden_dim),
                nn.Linear(self.hidden_dim, self.hidden_dim * 2),
                nn.ReLU(),
                # NOTE(review): dropout default here is 0.1; confirm it should
                # match the parent encoder's dropout default.
                nn.Dropout(kwargs.get('dropout', 0.1)),
                nn.Linear(self.hidden_dim * 2, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim),
            )
            for modality in kwargs['modality_channels'].keys()
        })

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:

        """ Forward implementation with residual pathways

        Args:
            inputs: Mapping of modality name -> input tensor; entries whose
                key has no registered encoder, or whose value is None, are
                skipped.
            mask: Optional attention mask forwarded to cross-attention and
                the fusion module.

        Returns:
            Output tensor produced by ``self.final_block``.

        Raises:
            ValueError: If no modality yields encoded features.
        """

        encoded_features = {}

        # Feature extraction with residual connections: each modality's
        # encoded features are added to their projected transform.
        for modality, x in inputs.items():
            if modality not in self.modality_encoders or x is None:
                continue

            encoded = self.modality_encoders[modality](x)
            projected = encoded + self.feature_projections[modality](encoded)
            encoded_features[modality] = projected

        if not encoded_features:
            raise ValueError("No valid features after encoding")

        # Gating with residual pathways: gated output plus the pre-gate input
        if self.use_gating:
            gated_features = self.gating(encoded_features)
            for modality in encoded_features:
                gated_features[modality] = gated_features[modality] + encoded_features[modality]
            encoded_features = gated_features

        # Attention with residual pathways (only when >1 modality survives)
        if self.use_cross_attention and len(encoded_features) > 1:
            attended_features = self.cross_attention(encoded_features, mask)
            for modality in encoded_features:
                attended_features[modality] = attended_features[modality] + encoded_features[modality]
            encoded_features = attended_features

        # Final fusion and output generation.
        # NOTE(review): repeat(1, self.sequence_length) tiles the last
        # dimension sequence_length times; this only matches final_block's
        # first Linear if fusion_module returns a per-timestep feature vector
        # of size hidden_dim / sequence_length — confirm against the parent's
        # fusion_module output shape.
        fused_features = self.fusion_module(encoded_features, mask)
        fused_features = fused_features.repeat(1, self.sequence_length)
        output = self.final_block(fused_features)

        return output