|
| 1 | +# multimodal_dynamic.py |
| 2 | + |
| 3 | +""" |
| 4 | +Dynamic fusion model definition |
| 5 | +""" |
| 6 | + |
| 7 | +from collections import OrderedDict |
| 8 | +from typing import Optional, Dict |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch import nn |
| 12 | +from ocf_datapipes.batch import BatchKey, NWPBatchKey |
| 13 | +from omegaconf import DictConfig |
| 14 | + |
| 15 | +import pvnet |
| 16 | +from pvnet.models.multimodal.basic_blocks import ImageEmbedding |
| 17 | +from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder |
| 18 | +from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork |
| 19 | +from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder |
| 20 | +from pvnet.models.multimodal.multimodal_base import MultimodalBaseModel |
| 21 | +from pvnet.optimizers import AbstractOptimizer |
| 22 | +from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule |
| 23 | +from pvnet.models.multimodal.attention_blocks import CrossModalAttention |
| 24 | + |
| 25 | + |
| 26 | +class Model(MultimodalBaseModel): |
| 27 | + """ |
| 28 | + Architecture summarised as follows: |
| 29 | +
|
| 30 | + - Each modality encoded separately |
| 31 | + - Cross modal attention - early feature interaction |
| 32 | + - Dynamic weighting - modality importance |
| 33 | + - Weighted combination - final fused representation |
| 34 | + """ |
| 35 | + |
| 36 | + name = "dynamic_fusion" |
| 37 | + |
| 38 | + def __init__( |
| 39 | + self, |
| 40 | + output_network: AbstractLinearNetwork, |
| 41 | + output_quantiles: Optional[list[float]] = None, |
| 42 | + nwp_encoders_dict: Optional[dict[AbstractNWPSatelliteEncoder]] = None, |
| 43 | + sat_encoder: Optional[AbstractNWPSatelliteEncoder] = None, |
| 44 | + pv_encoder: Optional[AbstractPVSitesEncoder] = None, |
| 45 | + wind_encoder: Optional[AbstractPVSitesEncoder] = None, |
| 46 | + sensor_encoder: Optional[AbstractPVSitesEncoder] = None, |
| 47 | + add_image_embedding_channel: bool = False, |
| 48 | + include_gsp_yield_history: bool = True, |
| 49 | + include_sun: bool = True, |
| 50 | + include_time: bool = False, |
| 51 | + embedding_dim: Optional[int] = 16, |
| 52 | + fusion_hidden_dim: int = 256, |
| 53 | + num_fusion_heads: int = 8, |
| 54 | + fusion_dropout: float = 0.1, |
| 55 | + use_cross_attention: bool = True, |
| 56 | + fusion_method: str = "weighted_sum", |
| 57 | + forecast_minutes: int = 30, |
| 58 | + history_minutes: int = 60, |
| 59 | + sat_history_minutes: Optional[int] = None, |
| 60 | + min_sat_delay_minutes: Optional[int] = 30, |
| 61 | + nwp_forecast_minutes: Optional[DictConfig] = None, |
| 62 | + nwp_history_minutes: Optional[DictConfig] = None, |
| 63 | + pv_history_minutes: Optional[int] = None, |
| 64 | + wind_history_minutes: Optional[int] = None, |
| 65 | + sensor_history_minutes: Optional[int] = None, |
| 66 | + sensor_forecast_minutes: Optional[int] = None, |
| 67 | + optimizer: AbstractOptimizer = pvnet.optimizers.Adam(), |
| 68 | + target_key: str = "gsp", |
| 69 | + interval_minutes: int = 30, |
| 70 | + nwp_interval_minutes: Optional[DictConfig] = None, |
| 71 | + pv_interval_minutes: int = 5, |
| 72 | + sat_interval_minutes: int = 5, |
| 73 | + sensor_interval_minutes: int = 30, |
| 74 | + wind_interval_minutes: int = 15, |
| 75 | + num_embeddings: Optional[int] = 318, |
| 76 | + timestep_intervals_to_plot: Optional[list[int]] = None, |
| 77 | + adapt_batches: Optional[bool] = False, |
| 78 | + use_weighted_loss: Optional[bool] = False, |
| 79 | + forecast_minutes_ignore: Optional[int] = 0, |
| 80 | + ): |
| 81 | + |
| 82 | + self.include_gsp_yield_history = include_gsp_yield_history |
| 83 | + self.include_sat = sat_encoder is not None |
| 84 | + self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0 |
| 85 | + self.include_pv = pv_encoder is not None |
| 86 | + self.include_sun = include_sun |
| 87 | + self.include_time = include_time |
| 88 | + self.include_wind = wind_encoder is not None |
| 89 | + self.include_sensor = sensor_encoder is not None |
| 90 | + self.embedding_dim = embedding_dim |
| 91 | + self.add_image_embedding_channel = add_image_embedding_channel |
| 92 | + self.interval_minutes = interval_minutes |
| 93 | + self.min_sat_delay_minutes = min_sat_delay_minutes |
| 94 | + self.adapt_batches = adapt_batches |
| 95 | + |
| 96 | + super().__init__( |
| 97 | + history_minutes=history_minutes, |
| 98 | + forecast_minutes=forecast_minutes, |
| 99 | + optimizer=optimizer, |
| 100 | + output_quantiles=output_quantiles, |
| 101 | + target_key=target_key, |
| 102 | + interval_minutes=interval_minutes, |
| 103 | + timestep_intervals_to_plot=timestep_intervals_to_plot, |
| 104 | + use_weighted_loss=use_weighted_loss, |
| 105 | + forecast_minutes_ignore=forecast_minutes_ignore, |
| 106 | + ) |
| 107 | + |
| 108 | + feature_dims = {} |
| 109 | + |
| 110 | + if self.include_sat: |
| 111 | + assert sat_history_minutes is not None |
| 112 | + |
| 113 | + self.sat_sequence_len = ( |
| 114 | + sat_history_minutes - min_sat_delay_minutes |
| 115 | + ) // sat_interval_minutes + 1 |
| 116 | + |
| 117 | + self.sat_encoder = sat_encoder( |
| 118 | + sequence_length=self.sat_sequence_len, |
| 119 | + in_channels=sat_encoder.keywords["in_channels"] + add_image_embedding_channel, |
| 120 | + ) |
| 121 | + if add_image_embedding_channel: |
| 122 | + self.sat_embed = ImageEmbedding( |
| 123 | + num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels |
| 124 | + ) |
| 125 | + |
| 126 | + feature_dims["sat"] = self.sat_encoder.out_features |
| 127 | + |
| 128 | + if self.include_nwp: |
| 129 | + assert nwp_forecast_minutes is not None |
| 130 | + assert nwp_history_minutes is not None |
| 131 | + |
| 132 | + assert set(nwp_encoders_dict.keys()) == set(nwp_forecast_minutes.keys()) |
| 133 | + assert set(nwp_encoders_dict.keys()) == set(nwp_history_minutes.keys()) |
| 134 | + |
| 135 | + if nwp_interval_minutes is None: |
| 136 | + nwp_interval_minutes = dict.fromkeys(nwp_encoders_dict.keys(), 60) |
| 137 | + |
| 138 | + self.nwp_encoders_dict = torch.nn.ModuleDict() |
| 139 | + if add_image_embedding_channel: |
| 140 | + self.nwp_embed_dict = torch.nn.ModuleDict() |
| 141 | + |
| 142 | + for nwp_source in nwp_encoders_dict.keys(): |
| 143 | + nwp_sequence_len = ( |
| 144 | + nwp_history_minutes[nwp_source] // nwp_interval_minutes[nwp_source] |
| 145 | + + nwp_forecast_minutes[nwp_source] // nwp_interval_minutes[nwp_source] |
| 146 | + + 1 |
| 147 | + ) |
| 148 | + |
| 149 | + self.nwp_encoders_dict[nwp_source] = nwp_encoders_dict[nwp_source]( |
| 150 | + sequence_length=nwp_sequence_len, |
| 151 | + in_channels=( |
| 152 | + nwp_encoders_dict[nwp_source].keywords["in_channels"] |
| 153 | + + add_image_embedding_channel |
| 154 | + ), |
| 155 | + ) |
| 156 | + if add_image_embedding_channel: |
| 157 | + self.nwp_embed_dict[nwp_source] = ImageEmbedding( |
| 158 | + num_embeddings, |
| 159 | + nwp_sequence_len, |
| 160 | + self.nwp_encoders_dict[nwp_source].image_size_pixels, |
| 161 | + ) |
| 162 | + |
| 163 | + feature_dims[f"nwp/{nwp_source}"] = self.nwp_encoders_dict[nwp_source].out_features |
| 164 | + |
| 165 | + if self.include_pv: |
| 166 | + assert pv_history_minutes is not None |
| 167 | + |
| 168 | + self.pv_encoder = pv_encoder( |
| 169 | + sequence_length=pv_history_minutes // pv_interval_minutes + 1, |
| 170 | + target_key_to_use=self._target_key_name, |
| 171 | + input_key_to_use="pv", |
| 172 | + ) |
| 173 | + |
| 174 | + feature_dims["pv"] = self.pv_encoder.out_features |
| 175 | + |
| 176 | + if self.include_wind: |
| 177 | + if wind_history_minutes is None: |
| 178 | + wind_history_minutes = history_minutes |
| 179 | + |
| 180 | + self.wind_encoder = wind_encoder( |
| 181 | + sequence_length=wind_history_minutes // wind_interval_minutes + 1, |
| 182 | + target_key_to_use=self._target_key_name, |
| 183 | + input_key_to_use="wind", |
| 184 | + ) |
| 185 | + |
| 186 | + feature_dims["wind"] = self.wind_encoder.out_features |
| 187 | + |
| 188 | + if self.include_sensor: |
| 189 | + if sensor_history_minutes is None: |
| 190 | + sensor_history_minutes = history_minutes |
| 191 | + if sensor_forecast_minutes is None: |
| 192 | + sensor_forecast_minutes = forecast_minutes |
| 193 | + |
| 194 | + self.sensor_encoder = sensor_encoder( |
| 195 | + sequence_length=sensor_history_minutes // sensor_interval_minutes |
| 196 | + + sensor_forecast_minutes // sensor_interval_minutes |
| 197 | + + 1, |
| 198 | + target_key_to_use=self._target_key_name, |
| 199 | + input_key_to_use="sensor", |
| 200 | + ) |
| 201 | + |
| 202 | + feature_dims["sensor"] = self.sensor_encoder.out_features |
| 203 | + |
| 204 | + if self.embedding_dim: |
| 205 | + self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) |
| 206 | + feature_dims["embedding"] = embedding_dim |
| 207 | + |
| 208 | + if self.include_sun: |
| 209 | + self.sun_fc1 = nn.Linear( |
| 210 | + in_features=2 * (self.forecast_len + self.forecast_len_ignore + self.history_len + 1), |
| 211 | + out_features=16, |
| 212 | + ) |
| 213 | + feature_dims["sun"] = 16 |
| 214 | + |
| 215 | + if self.include_time: |
| 216 | + self.time_fc1 = nn.Linear( |
| 217 | + in_features=4 * (self.forecast_len + self.forecast_len_ignore + self.history_len + 1), |
| 218 | + out_features=32, |
| 219 | + ) |
| 220 | + feature_dims["time"] = 32 |
| 221 | + |
| 222 | + if include_gsp_yield_history: |
| 223 | + feature_dims["gsp"] = self.history_len |
| 224 | + |
| 225 | + self.fusion_module = DynamicFusionModule( |
| 226 | + feature_dims=feature_dims, |
| 227 | + hidden_dim=fusion_hidden_dim, |
| 228 | + num_heads=num_fusion_heads, |
| 229 | + dropout=fusion_dropout, |
| 230 | + fusion_method=fusion_method, |
| 231 | + use_residual=True |
| 232 | + ) |
| 233 | + |
| 234 | + if use_cross_attention: |
| 235 | + self.cross_attention = CrossModalAttention( |
| 236 | + embed_dim=fusion_hidden_dim, |
| 237 | + num_heads=num_fusion_heads, |
| 238 | + dropout=fusion_dropout, |
| 239 | + num_modalities=len(feature_dims) |
| 240 | + ) |
| 241 | + else: |
| 242 | + self.cross_attention = None |
| 243 | + |
| 244 | + self.output_network = output_network( |
| 245 | + in_features=fusion_hidden_dim, |
| 246 | + out_features=self.num_output_features, |
| 247 | + ) |
| 248 | + |
| 249 | + self.save_hyperparameters() |
| 250 | + |
| 251 | + def forward(self, x): |
| 252 | + |
| 253 | + if self.adapt_batches: |
| 254 | + x = self._adapt_batch(x) |
| 255 | + |
| 256 | + encoded_features = OrderedDict() |
| 257 | + |
| 258 | + if self.include_sat: |
| 259 | + sat_data = x[BatchKey.satellite_actual][:, : self.sat_sequence_len] |
| 260 | + sat_data = torch.swapaxes(sat_data, 1, 2).float() |
| 261 | + |
| 262 | + if self.add_image_embedding_channel: |
| 263 | + id = x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int() |
| 264 | + sat_data = self.sat_embed(sat_data, id) |
| 265 | + encoded_features["sat"] = self.sat_encoder(sat_data) |
| 266 | + |
| 267 | + if self.include_nwp: |
| 268 | + for nwp_source in self.nwp_encoders_dict: |
| 269 | + nwp_data = x[BatchKey.nwp][nwp_source][NWPBatchKey.nwp].float() |
| 270 | + nwp_data = torch.swapaxes(nwp_data, 1, 2) |
| 271 | + nwp_data = torch.clip(nwp_data, min=-50, max=50) |
| 272 | + |
| 273 | + if self.add_image_embedding_channel: |
| 274 | + id = x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int() |
| 275 | + nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, id) |
| 276 | + |
| 277 | + encoded_features[f"nwp/{nwp_source}"] = self.nwp_encoders_dict[nwp_source](nwp_data) |
| 278 | + |
| 279 | + if self.include_pv: |
| 280 | + if self._target_key_name != "pv": |
| 281 | + encoded_features["pv"] = self.pv_encoder(x) |
| 282 | + else: |
| 283 | + x_tmp = x.copy() |
| 284 | + x_tmp[BatchKey.pv] = x_tmp[BatchKey.pv][:, : self.history_len + 1] |
| 285 | + encoded_features["pv"] = self.pv_encoder(x_tmp) |
| 286 | + |
| 287 | + if self.include_gsp_yield_history: |
| 288 | + gsp_history = x[BatchKey.gsp][:, : self.history_len].float() |
| 289 | + encoded_features["gsp"] = gsp_history.reshape(gsp_history.shape[0], -1) |
| 290 | + |
| 291 | + if self.embedding_dim: |
| 292 | + id = x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int() |
| 293 | + encoded_features["embedding"] = self.embed(id) |
| 294 | + |
| 295 | + if self.include_wind: |
| 296 | + if self._target_key_name != "wind": |
| 297 | + encoded_features["wind"] = self.wind_encoder(x) |
| 298 | + else: |
| 299 | + x_tmp = x.copy() |
| 300 | + x_tmp[BatchKey.wind] = x_tmp[BatchKey.wind][:, : self.history_len + 1] |
| 301 | + encoded_features["wind"] = self.wind_encoder(x_tmp) |
| 302 | + |
| 303 | + if self.include_sensor: |
| 304 | + if self._target_key_name != "sensor": |
| 305 | + encoded_features["sensor"] = self.sensor_encoder(x) |
| 306 | + else: |
| 307 | + x_tmp = x.copy() |
| 308 | + x_tmp[BatchKey.sensor] = x_tmp[BatchKey.sensor][:, : self.history_len + 1] |
| 309 | + encoded_features["sensor"] = self.sensor_encoder(x_tmp) |
| 310 | + |
| 311 | + if self.include_sun: |
| 312 | + sun = torch.cat( |
| 313 | + ( |
| 314 | + x[BatchKey[f"{self._target_key_name}_solar_azimuth"]], |
| 315 | + x[BatchKey[f"{self._target_key_name}_solar_elevation"]], |
| 316 | + ), |
| 317 | + dim=1, |
| 318 | + ).float() |
| 319 | + encoded_features["sun"] = self.sun_fc1(sun) |
| 320 | + |
| 321 | + |
| 322 | + if self.include_time: |
| 323 | + time = torch.cat( |
| 324 | + ( |
| 325 | + x[f"{self._target_key_name}_date_sin"], |
| 326 | + x[f"{self._target_key_name}_date_cos"], |
| 327 | + x[f"{self._target_key_name}_time_sin"], |
| 328 | + x[f"{self._target_key_name}_time_cos"], |
| 329 | + ), |
| 330 | + dim=1, |
| 331 | + ).float() |
| 332 | + encoded_features["time"] = self.time_fc1(time) |
| 333 | + |
| 334 | + if self.cross_attention is not None and len(encoded_features) > 1: |
| 335 | + encoded_features = self.cross_attention(encoded_features) |
| 336 | + |
| 337 | + fused_features = self.fusion_module(encoded_features) |
| 338 | + out = self.output_network(fused_features) |
| 339 | + |
| 340 | + if self.use_quantile_regression: |
| 341 | + out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles)) |
| 342 | + |
| 343 | + return out |
0 commit comments