1
1
# multimodal_dynamic.py
2
2
3
- from collections import OrderedDict
4
- from typing import Optional , Dict , List , Tuple , Any , Union
5
- import logging
6
3
4
+ """
5
+ Dynamic multimodal fusion model implementation
6
+
7
+ Model class for the multimodal fusion architecture: integrates multiple modality encoders and fusion mechanisms
8
+
9
+ Implementation permits dynamic fusion through attention-based mechanisms and modality-specific processing stages
10
+ """
11
+
12
+ import logging
13
+ import pvnet
7
14
import torch
15
+
8
16
from torch import nn
9
17
from ocf_datapipes .batch import BatchKey , NWPBatchKey
10
18
from omegaconf import DictConfig
19
+ from collections import OrderedDict
20
+ from typing import Optional , Dict , List , Tuple , Any , Union
11
21
12
- import pvnet
13
22
from pvnet .models .multimodal .basic_blocks import ImageEmbedding
14
23
from pvnet .models .multimodal .encoders .dynamic_encoder import DynamicFusionEncoder
15
24
from pvnet .models .multimodal .linear_networks .basic_blocks import AbstractLinearNetwork
16
25
from pvnet .models .multimodal .site_encoders .basic_blocks import AbstractPVSitesEncoder
17
26
from pvnet .models .multimodal .multimodal_base import MultimodalBaseModel
18
27
from pvnet .optimizers import AbstractOptimizer
19
28
29
+
20
30
logger = logging .getLogger (__name__ )
21
31
32
+
22
33
class Model (MultimodalBaseModel ):
34
+ """
35
+ Dynamic multimodal fusion model definition
36
+
37
+ Implements fusion of M modalities through attention-based mechanisms
38
+ Supports heterogeneous input spaces
39
+ # X_m ∈ ℝ^{d_m} for m ∈ M
40
+ """
41
+
23
42
name = "dynamic_fusion"
24
43
44
+ # Model initialisation
25
45
def __init__ (
26
46
self ,
27
47
output_network : AbstractLinearNetwork ,
@@ -106,7 +126,7 @@ def __init__(
106
126
forecast_minutes_ignore = forecast_minutes_ignore ,
107
127
)
108
128
109
- self ._initialize_model_config (
129
+ self ._initialise_model_config (
110
130
include_gsp_yield_history = include_gsp_yield_history ,
111
131
nwp_encoders_dict = nwp_encoders_dict ,
112
132
pv_encoder = pv_encoder ,
@@ -128,7 +148,7 @@ def __init__(
128
148
nwp_history_minutes = nwp_history_minutes
129
149
)
130
150
131
- self .encoder = self ._initialize_fusion_encoder (
151
+ self .encoder = self ._initialise_fusion_encoder (
132
152
modality_channels = modality_channels ,
133
153
fusion_hidden_dim = fusion_hidden_dim ,
134
154
num_fusion_heads = num_fusion_heads ,
@@ -142,13 +162,17 @@ def __init__(
142
162
)
143
163
144
164
self .save_hyperparameters ()
145
- logger .info (f"Initialized { self .name } model with { len (modality_channels )} modalities" )
165
+ logger .info (f"Initialised { self .name } model with { len (modality_channels )} modalities" )
146
166
147
167
def _validate_inputs (self , ** kwargs ):
168
+ """ Validation - architectural hyperparameters / input config """
169
+
148
170
if kwargs ['fusion_hidden_dim' ] <= 0 :
149
171
raise ValueError ("fusion_hidden_dim must be positive" )
172
+
150
173
if kwargs ['num_fusion_heads' ] <= 0 :
151
174
raise ValueError ("num_fusion_heads must be positive" )
175
+
152
176
if kwargs ['fusion_method' ] not in ["weighted_sum" , "concat" ]:
153
177
raise ValueError (f"Invalid fusion method: { kwargs ['fusion_method' ]} " )
154
178
@@ -161,7 +185,9 @@ def _validate_inputs(self, **kwargs):
161
185
if kwargs ['pv_encoder' ] is not None and kwargs ['pv_history_minutes' ] is None :
162
186
raise ValueError ("pv_history_minutes required when using PV encoder" )
163
187
164
- def _initialize_model_config (self , ** kwargs ):
188
+ def _initialise_model_config (self , ** kwargs ):
189
+ """ Configuration of model architecture / modality-specific parameters """
190
+
165
191
config_params = {
166
192
k : v for k , v in kwargs .items ()
167
193
if not k .startswith ('include_' )
@@ -176,8 +202,13 @@ def _initialize_model_config(self, **kwargs):
176
202
self .nwp_encoders_dict = {}
177
203
178
204
def _setup_modality_channels (self , ** kwargs ) -> Dict [str , int ]:
205
+ """ Modality-specific channel configurations """
206
+
179
207
modality_channels = {}
180
208
209
+ # Defines input dimension for each modality
210
+ # Returns mapping
211
+ # m ∈ M → d_m
181
212
if self .embedding_dim :
182
213
modality_channels ["embedding" ] = self .embedding_dim
183
214
@@ -189,6 +220,11 @@ def _setup_modality_channels(self, **kwargs) -> Dict[str, int]:
189
220
return modality_channels
190
221
191
222
def _setup_nwp_channels (self , modality_channels : Dict [str , int ], ** kwargs ):
223
+ """ NWP channel configuration """
224
+
225
+ # Defines temporal sequence length / channel dimension
226
+ # Mapping for NWP features
227
+ # (L,C) → ℝ^{L×C}
192
228
nwp_interval_minutes = kwargs .get ('nwp_interval_minutes' )
193
229
if nwp_interval_minutes is None :
194
230
nwp_interval_minutes = dict .fromkeys (self .nwp_encoders_dict .keys (), 60 )
@@ -210,23 +246,36 @@ def _setup_nwp_channels(self, modality_channels: Dict[str, int], **kwargs):
210
246
modality_channels [f"nwp/{ nwp_source } " ] = nwp_channels
211
247
212
248
def _add_additional_channels(self, modality_channels: Dict[str, int]) -> None:
    """Register channel counts for the optional non-NWP modalities.

    ``modality_channels`` is mutated in place and nothing is returned. An
    entry is written only when the matching ``self.include_*`` flag is
    truthy; entries are added in the order pv, wind, sensor, sun, time, gsp.

    Args:
        modality_channels: Mapping of modality name to channel count that
            is extended by this call.
    """
    # Site-based encoders: the channel count is the "num_sites" entry of the
    # encoder's ``keywords`` mapping, falling back to 1 when it is absent.
    # NOTE(review): ``.keywords`` suggests the encoders are functools.partial
    # objects - confirm against the caller.
    for modality in ("pv", "wind", "sensor"):
        if getattr(self, f"include_{modality}"):
            encoder = getattr(self, f"{modality}_encoder")
            modality_channels[modality] = encoder.keywords.get("num_sites", 1)

    # Sun and time features both use the fusion hidden dimension.
    for modality in ("sun", "time"):
        if getattr(self, f"include_{modality}"):
            modality_channels[modality] = self.fusion_hidden_dim

    # GSP yield history uses the history length as its channel count.
    if self.include_gsp_yield_history:
        modality_channels["gsp"] = self.history_len
225
267
226
- def _initialize_fusion_encoder (self , modality_channels : Dict [str , int ], fusion_hidden_dim : int ,
268
+ def _initialise_fusion_encoder (self , modality_channels : Dict [str , int ], fusion_hidden_dim : int ,
227
269
num_fusion_heads : int , fusion_dropout : float , fusion_method : str ) -> DynamicFusionEncoder :
270
+ """ Initialisation of dynamic fusion encoder """
271
+
228
272
modality_encoders = {}
229
273
274
+ # Modality encoders φ_m: X_m → ℝ^H
275
+ # Cross attention A_c: ⊗_{m∈M} ℝ^H → ℝ^H
276
+ # Modality gating g_m: ℝ^H → [0,1]
277
+ # Dynamic fusion F: {ℝ^H}^M → ℝ^H
278
+
230
279
if self .include_nwp :
231
280
for nwp_source , encoder in self .nwp_encoders_dict .items ():
232
281
modality_encoders [f"nwp/{ nwp_source } " ] = {
@@ -253,9 +302,15 @@ def _initialize_fusion_encoder(self, modality_channels: Dict[str, int], fusion_h
253
302
)
254
303
255
304
def forward (self , x : Dict [str , torch .Tensor ]) -> Tuple [torch .Tensor , Dict [str , torch .Tensor ]]:
305
+ """ Forward pass implementation """
306
+
307
+ # f: X → Y mapping in feature space
308
+ # x_m ∈ X_m → y ∈ ℝ^O
256
309
if self .adapt_batches :
257
310
x = self ._adapt_batch (x )
258
311
312
+ # Input feature collection
313
+ # X = {x_m}_{m∈M}
259
314
inputs = {}
260
315
261
316
if self .include_nwp :
@@ -286,14 +341,62 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, t
286
341
inputs ["time" ] = self .time_fc1 (time_features )
287
342
288
343
encoded_features = self .encoder (inputs )
344
+ print (f"Encoded features shape: { encoded_features .shape } " )
345
+
346
+ # Dimension validation and expansion
347
+ if encoded_features .dim () == 2 :
348
+
349
+ # Single feature expansion
350
+ # π: ℝ^1 → ℝ^H
351
+ if encoded_features .size (1 ) == 1 :
352
+ # Repeat to match hidden dimension
353
+ encoded_features = encoded_features .repeat (1 , self .fusion_hidden_dim )
354
+
355
+
356
+ # Quantile feature preparation
357
+ # Q: ℝ^H → ℝ^{H×q}, q: number of quantiles
358
+ if self .use_quantile_regression and self .output_quantiles :
359
+ num_quantiles = len (self .output_quantiles )
360
+ batch_size = encoded_features .size (0 )
361
+
362
+ # Layer dimension matching
363
+ first_layer = list (self .output_network .layers )[0 ][0 ]
364
+ if hasattr (first_layer , 'in_features' ):
365
+ target_dim = first_layer .in_features
366
+
367
+ # Feature expansion and padding
368
+ # ξ: ℝ^H → ℝ^{H×q}
369
+ quantile_features = encoded_features .repeat (1 , num_quantiles )
370
+
371
+ # Dimension matching via truncation/padding
372
+ # π: ℝ^k → ℝ^d
373
+ if quantile_features .size (1 ) > target_dim :
374
+ quantile_features = quantile_features [:, :target_dim ]
375
+ elif quantile_features .size (1 ) < target_dim :
376
+ padding = torch .zeros (batch_size , target_dim - quantile_features .size (1 ),
377
+ device = quantile_features .device )
378
+ quantile_features = torch .cat ([quantile_features , padding ], dim = 1 )
379
+
380
+ encoded_features = quantile_features
381
+
382
+ # Output generation
383
+ # y = ψ(z)
289
384
output = self .output_network (encoded_features )
290
385
291
- if self .use_quantile_regression :
292
- output = output .reshape (output .shape [0 ], self .forecast_len , len (self .output_quantiles ))
386
+ # Quantile output reshaping
387
+ # ρ: ℝ^{B×T×q} → ℝ^{B×F×q}
388
+ if self .use_quantile_regression and self .output_quantiles :
389
+ output = output .reshape (
390
+ output .shape [0 ],
391
+ self .forecast_len ,
392
+ len (self .output_quantiles )
393
+ )
293
394
294
395
return output , encoded_features
295
396
296
397
def _process_nwp_data (self , x : Dict [str , torch .Tensor ], inputs : Dict [str , torch .Tensor ]):
398
+ """ Process NWP input features """
399
+
297
400
for nwp_source , nwp_encoder in self .nwp_encoders_dict .items ():
298
401
nwp_data = x [BatchKey .nwp ][nwp_source ][NWPBatchKey .nwp ].float ()
299
402
nwp_data = torch .swapaxes (nwp_data , 1 , 2 )
@@ -306,7 +409,13 @@ def _process_nwp_data(self, x: Dict[str, torch.Tensor], inputs: Dict[str, torch.
306
409
inputs [f"nwp/{ nwp_source } " ] = nwp_data
307
410
308
411
def _adapt_batch (self , batch : Dict [str , torch .Tensor ]) -> Dict [str , torch .Tensor ]:
412
+ """
413
+ Batch adaptation for tensor processing
414
+ Maps arbitrary inputs to T
415
+ """
416
+
309
417
adapted_batch = {}
418
+
310
419
for key , value in batch .items ():
311
420
if isinstance (value , (torch .Tensor , dict )):
312
421
adapted_batch [key ] = value
@@ -318,11 +427,17 @@ def _adapt_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
318
427
return adapted_batch
319
428
320
429
def _preprocess_features(self, x: torch.Tensor, modality: str) -> torch.Tensor:
    """Modality-specific feature preprocessing (π_m: X_m → X̂_m).

    The input is cast with ``Tensor.float()`` for every modality, so the
    returned dtype does not depend on which branch is taken. Previously the
    "nwp" branch skipped the cast and returned the caller's dtype.

    Args:
        x: Input tensor for a single modality.
        modality: Modality name; ``"nwp"`` selects the NWP path.

    Returns:
        The cast tensor. For ``"nwp"``, axes 1 and 2 are additionally
        swapped and values are clipped to the range [-50, 50].
    """
    features = x.float()
    if modality == "nwp":
        # Swap axes 1 and 2, then bound the values.
        # NOTE(review): the reason for the ±50 bound is not visible here.
        return torch.clip(torch.swapaxes(features, 1, 2), min=-50, max=50)
    return features
324
436
325
437
def _prepare_time_features (self , x : Dict [str , torch .Tensor ]) -> torch .Tensor :
438
+ """ Temporal feature preparation - cyclic encoding """
439
+
440
+ # τ: T → ℝ^4
326
441
return torch .cat ((
327
442
x [f"{ self ._target_key_name } _date_sin" ],
328
443
x [f"{ self ._target_key_name } _date_cos" ],
@@ -331,6 +446,9 @@ def _prepare_time_features(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
331
446
), dim = 1 ).float ()
332
447
333
448
def _prepare_sun_features (self , x : Dict [str , torch .Tensor ]) -> torch .Tensor :
449
+ """ Solar feature preparation """
450
+
451
+ # σ: S → ℝ^2
334
452
return torch .cat ((
335
453
x [BatchKey [f"{ self ._target_key_name } _solar_azimuth" ]],
336
454
x [BatchKey [f"{ self ._target_key_name } _solar_elevation" ]],
0 commit comments