Skip to content

Commit bb013bd

Browse files
committed
Dynamic model script added
1 parent 7e74993 commit bb013bd

File tree

1 file changed

+129
-11
lines changed

1 file changed

+129
-11
lines changed

pvnet/models/multimodal/multimodal_dynamic.py

Lines changed: 129 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,47 @@
11
# multimodal_dynamic.py
22

3-
from collections import OrderedDict
4-
from typing import Optional, Dict, List, Tuple, Any, Union
5-
import logging
63

4+
"""
5+
Dynamic multimodal fusion model implementation
6+
7+
Model class for the multimodal fusion architecture, which integrates multiple modality encoders and fusion mechanisms
8+
9+
Implementation permits dynamic fusion through attention-based mechanisms and modality-specific processing stages
10+
"""
11+
12+
import logging
13+
import pvnet
714
import torch
15+
816
from torch import nn
917
from ocf_datapipes.batch import BatchKey, NWPBatchKey
1018
from omegaconf import DictConfig
19+
from collections import OrderedDict
20+
from typing import Optional, Dict, List, Tuple, Any, Union
1121

12-
import pvnet
1322
from pvnet.models.multimodal.basic_blocks import ImageEmbedding
1423
from pvnet.models.multimodal.encoders.dynamic_encoder import DynamicFusionEncoder
1524
from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
1625
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder
1726
from pvnet.models.multimodal.multimodal_base import MultimodalBaseModel
1827
from pvnet.optimizers import AbstractOptimizer
1928

29+
2030
logger = logging.getLogger(__name__)
2131

32+
2233
class Model(MultimodalBaseModel):
34+
"""
35+
Dynamic multimodal fusion model definition
36+
37+
Implements fusion of M modalities through attention-based mechanisms
38+
Supports heterogeneous input spaces
39+
where X_m ∈ ℝ^{d_m} for each modality m ∈ M
40+
"""
41+
2342
name = "dynamic_fusion"
2443

44+
# Model initialisation
2545
def __init__(
2646
self,
2747
output_network: AbstractLinearNetwork,
@@ -106,7 +126,7 @@ def __init__(
106126
forecast_minutes_ignore=forecast_minutes_ignore,
107127
)
108128

109-
self._initialize_model_config(
129+
self._initialise_model_config(
110130
include_gsp_yield_history=include_gsp_yield_history,
111131
nwp_encoders_dict=nwp_encoders_dict,
112132
pv_encoder=pv_encoder,
@@ -128,7 +148,7 @@ def __init__(
128148
nwp_history_minutes=nwp_history_minutes
129149
)
130150

131-
self.encoder = self._initialize_fusion_encoder(
151+
self.encoder = self._initialise_fusion_encoder(
132152
modality_channels=modality_channels,
133153
fusion_hidden_dim=fusion_hidden_dim,
134154
num_fusion_heads=num_fusion_heads,
@@ -142,13 +162,17 @@ def __init__(
142162
)
143163

144164
self.save_hyperparameters()
145-
logger.info(f"Initialized {self.name} model with {len(modality_channels)} modalities")
165+
logger.info(f"Initialised {self.name} model with {len(modality_channels)} modalities")
146166

147167
def _validate_inputs(self, **kwargs):
168+
""" Validation - architectural hyperparameters / input config """
169+
148170
if kwargs['fusion_hidden_dim'] <= 0:
149171
raise ValueError("fusion_hidden_dim must be positive")
172+
150173
if kwargs['num_fusion_heads'] <= 0:
151174
raise ValueError("num_fusion_heads must be positive")
175+
152176
if kwargs['fusion_method'] not in ["weighted_sum", "concat"]:
153177
raise ValueError(f"Invalid fusion method: {kwargs['fusion_method']}")
154178

@@ -161,7 +185,9 @@ def _validate_inputs(self, **kwargs):
161185
if kwargs['pv_encoder'] is not None and kwargs['pv_history_minutes'] is None:
162186
raise ValueError("pv_history_minutes required when using PV encoder")
163187

164-
def _initialize_model_config(self, **kwargs):
188+
def _initialise_model_config(self, **kwargs):
189+
""" Configuration of model architecture / modality-specific parameters """
190+
165191
config_params = {
166192
k: v for k, v in kwargs.items()
167193
if not k.startswith('include_')
@@ -176,8 +202,13 @@ def _initialize_model_config(self, **kwargs):
176202
self.nwp_encoders_dict = {}
177203

178204
def _setup_modality_channels(self, **kwargs) -> Dict[str, int]:
205+
""" Modality-specific channel configurations """
206+
179207
modality_channels = {}
180208

209+
# Defines input dimension for each modality
210+
# Returns mapping
211+
# m ∈ M → d_m
181212
if self.embedding_dim:
182213
modality_channels["embedding"] = self.embedding_dim
183214

@@ -189,6 +220,11 @@ def _setup_modality_channels(self, **kwargs) -> Dict[str, int]:
189220
return modality_channels
190221

191222
def _setup_nwp_channels(self, modality_channels: Dict[str, int], **kwargs):
223+
""" NWP channel configuration """
224+
225+
# Defines temporal sequence length / channel dimension
226+
# Mapping for NWP features
227+
# (L,C) → ℝ^{L×C}
192228
nwp_interval_minutes = kwargs.get('nwp_interval_minutes')
193229
if nwp_interval_minutes is None:
194230
nwp_interval_minutes = dict.fromkeys(self.nwp_encoders_dict.keys(), 60)
@@ -210,23 +246,36 @@ def _setup_nwp_channels(self, modality_channels: Dict[str, int], **kwargs):
210246
modality_channels[f"nwp/{nwp_source}"] = nwp_channels
211247

212248
def _add_additional_channels(self, modality_channels: Dict[str, int]):
249+
213250
if self.include_pv:
214251
modality_channels["pv"] = self.pv_encoder.keywords.get("num_sites", 1)
252+
215253
if self.include_wind:
216254
modality_channels["wind"] = self.wind_encoder.keywords.get("num_sites", 1)
255+
217256
if self.include_sensor:
218257
modality_channels["sensor"] = self.sensor_encoder.keywords.get("num_sites", 1)
258+
219259
if self.include_sun:
220260
modality_channels["sun"] = self.fusion_hidden_dim
261+
221262
if self.include_time:
222263
modality_channels["time"] = self.fusion_hidden_dim
264+
223265
if self.include_gsp_yield_history:
224266
modality_channels["gsp"] = self.history_len
225267

226-
def _initialize_fusion_encoder(self, modality_channels: Dict[str, int], fusion_hidden_dim: int,
268+
def _initialise_fusion_encoder(self, modality_channels: Dict[str, int], fusion_hidden_dim: int,
227269
num_fusion_heads: int, fusion_dropout: float, fusion_method: str) -> DynamicFusionEncoder:
270+
""" Initialisation of dynamic fusion encoder """
271+
228272
modality_encoders = {}
229273

274+
# Modality encoders φ_m: X_m → ℝ^H
275+
# Cross attention A_c: ⊗_{m∈M} ℝ^H → ℝ^H
276+
# Modality gating g_m: ℝ^H → [0,1]
277+
# Dynamic fusion F: {ℝ^H}^M → ℝ^H
278+
230279
if self.include_nwp:
231280
for nwp_source, encoder in self.nwp_encoders_dict.items():
232281
modality_encoders[f"nwp/{nwp_source}"] = {
@@ -253,9 +302,15 @@ def _initialize_fusion_encoder(self, modality_channels: Dict[str, int], fusion_h
253302
)
254303

255304
def forward(self, x: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
305+
""" Forward pass implementation """
306+
307+
# f: X → Y mapping in feature space
308+
# x_m ∈ X_m → y ∈ ℝ^O
256309
if self.adapt_batches:
257310
x = self._adapt_batch(x)
258311

312+
# Input feature collection
313+
# X = {x_m}_{m∈M}
259314
inputs = {}
260315

261316
if self.include_nwp:
@@ -286,14 +341,62 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, t
286341
inputs["time"] = self.time_fc1(time_features)
287342

288343
encoded_features = self.encoder(inputs)
344+
print(f"Encoded features shape: {encoded_features.shape}")
345+
346+
# Dimension validation and expansion
347+
if encoded_features.dim() == 2:
348+
349+
# Single feature expansion
350+
# π: ℝ^1 → ℝ^H
351+
if encoded_features.size(1) == 1:
352+
# Repeat to match hidden dimension
353+
encoded_features = encoded_features.repeat(1, self.fusion_hidden_dim)
354+
355+
356+
# Quantile feature preparation
357+
# Q: ℝ^H → ℝ^{H·q} (features tiled q times along dim 1), q: number of quantiles
358+
if self.use_quantile_regression and self.output_quantiles:
359+
num_quantiles = len(self.output_quantiles)
360+
batch_size = encoded_features.size(0)
361+
362+
# Layer dimension matching
363+
first_layer = list(self.output_network.layers)[0][0]
364+
if hasattr(first_layer, 'in_features'):
365+
target_dim = first_layer.in_features
366+
367+
# Feature expansion and padding
368+
# ξ: ℝ^H → ℝ^{H·q}
369+
quantile_features = encoded_features.repeat(1, num_quantiles)
370+
371+
# Dimension matching via truncation/padding
372+
# ℝ^k → ℝ^d: truncate when k > d, zero-pad when k < d
373+
if quantile_features.size(1) > target_dim:
374+
quantile_features = quantile_features[:, :target_dim]
375+
elif quantile_features.size(1) < target_dim:
376+
padding = torch.zeros(batch_size, target_dim - quantile_features.size(1),
377+
device=quantile_features.device)
378+
quantile_features = torch.cat([quantile_features, padding], dim=1)
379+
380+
encoded_features = quantile_features
381+
382+
# Output generation
383+
# y = ψ(z)
289384
output = self.output_network(encoded_features)
290385

291-
if self.use_quantile_regression:
292-
output = output.reshape(output.shape[0], self.forecast_len, len(self.output_quantiles))
386+
# Quantile output reshaping
387+
# ρ: network output → ℝ^{B×F×q} (B: batch size, F: forecast_len, q: number of quantiles)
388+
if self.use_quantile_regression and self.output_quantiles:
389+
output = output.reshape(
390+
output.shape[0],
391+
self.forecast_len,
392+
len(self.output_quantiles)
393+
)
293394

294395
return output, encoded_features
295396

296397
def _process_nwp_data(self, x: Dict[str, torch.Tensor], inputs: Dict[str, torch.Tensor]):
398+
""" Process NWP input features """
399+
297400
for nwp_source, nwp_encoder in self.nwp_encoders_dict.items():
298401
nwp_data = x[BatchKey.nwp][nwp_source][NWPBatchKey.nwp].float()
299402
nwp_data = torch.swapaxes(nwp_data, 1, 2)
@@ -306,7 +409,13 @@ def _process_nwp_data(self, x: Dict[str, torch.Tensor], inputs: Dict[str, torch.
306409
inputs[f"nwp/{nwp_source}"] = nwp_data
307410

308411
def _adapt_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
412+
"""
413+
Batch adaptation for tensor processing
414+
Maps arbitrary inputs to T
415+
"""
416+
309417
adapted_batch = {}
418+
310419
for key, value in batch.items():
311420
if isinstance(value, (torch.Tensor, dict)):
312421
adapted_batch[key] = value
@@ -318,11 +427,17 @@ def _adapt_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
318427
return adapted_batch
319428

320429
def _preprocess_features(self, x: torch.Tensor, modality: str) -> torch.Tensor:
430+
""" Modality specific feature preprocessing """
431+
432+
# π_m: X_m → X̂_m
321433
if modality == "nwp":
322434
return torch.clip(torch.swapaxes(x, 1, 2), min=-50, max=50)
323435
return x.float()
324436

325437
def _prepare_time_features(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
438+
""" Temporal feature preparation - cyclic encoding """
439+
440+
# τ: T → ℝ^4
326441
return torch.cat((
327442
x[f"{self._target_key_name}_date_sin"],
328443
x[f"{self._target_key_name}_date_cos"],
@@ -331,6 +446,9 @@ def _prepare_time_features(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
331446
), dim=1).float()
332447

333448
def _prepare_sun_features(self, x: Dict[str, torch.Tensor]) -> torch.Tensor:
449+
""" Solar feature preparation """
450+
451+
# σ: S → ℝ^2
334452
return torch.cat((
335453
x[BatchKey[f"{self._target_key_name}_solar_azimuth"]],
336454
x[BatchKey[f"{self._target_key_name}_solar_elevation"]],

0 commit comments

Comments
 (0)