Skip to content

Commit 227f988

Browse files
committed
Multimodal dynamic - 'draft'
1 parent 3b8d59d commit 227f988

File tree

1 file changed

+343
-0
lines changed

1 file changed

+343
-0
lines changed
Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
# multimodal_dynamic.py
2+
3+
"""
4+
Dynamic fusion model definition
5+
"""
6+
7+
from collections import OrderedDict
8+
from typing import Optional, Dict
9+
10+
import torch
11+
from torch import nn
12+
from ocf_datapipes.batch import BatchKey, NWPBatchKey
13+
from omegaconf import DictConfig
14+
15+
import pvnet
16+
from pvnet.models.multimodal.basic_blocks import ImageEmbedding
17+
from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
18+
from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
19+
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder
20+
from pvnet.models.multimodal.multimodal_base import MultimodalBaseModel
21+
from pvnet.optimizers import AbstractOptimizer
22+
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule
23+
from pvnet.models.multimodal.attention_blocks import CrossModalAttention
24+
25+
26+
class Model(MultimodalBaseModel):
    """Multimodal forecasting model that fuses per-modality encodings dynamically.

    Architecture summarised as follows:

    - Each modality encoded separately
    - Cross modal attention - early feature interaction
    - Dynamic weighting - modality importance
    - Weighted combination - final fused representation
    """

    name = "dynamic_fusion"

    def __init__(
        self,
        output_network: AbstractLinearNetwork,
        output_quantiles: Optional[list[float]] = None,
        nwp_encoders_dict: Optional[dict[str, AbstractNWPSatelliteEncoder]] = None,
        sat_encoder: Optional[AbstractNWPSatelliteEncoder] = None,
        pv_encoder: Optional[AbstractPVSitesEncoder] = None,
        wind_encoder: Optional[AbstractPVSitesEncoder] = None,
        sensor_encoder: Optional[AbstractPVSitesEncoder] = None,
        add_image_embedding_channel: bool = False,
        include_gsp_yield_history: bool = True,
        include_sun: bool = True,
        include_time: bool = False,
        embedding_dim: Optional[int] = 16,
        fusion_hidden_dim: int = 256,
        num_fusion_heads: int = 8,
        fusion_dropout: float = 0.1,
        use_cross_attention: bool = True,
        fusion_method: str = "weighted_sum",
        forecast_minutes: int = 30,
        history_minutes: int = 60,
        sat_history_minutes: Optional[int] = None,
        min_sat_delay_minutes: Optional[int] = 30,
        nwp_forecast_minutes: Optional[DictConfig] = None,
        nwp_history_minutes: Optional[DictConfig] = None,
        pv_history_minutes: Optional[int] = None,
        wind_history_minutes: Optional[int] = None,
        sensor_history_minutes: Optional[int] = None,
        sensor_forecast_minutes: Optional[int] = None,
        optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
        target_key: str = "gsp",
        interval_minutes: int = 30,
        nwp_interval_minutes: Optional[DictConfig] = None,
        pv_interval_minutes: int = 5,
        sat_interval_minutes: int = 5,
        sensor_interval_minutes: int = 30,
        wind_interval_minutes: int = 15,
        num_embeddings: Optional[int] = 318,
        timestep_intervals_to_plot: Optional[list[int]] = None,
        adapt_batches: Optional[bool] = False,
        use_weighted_loss: Optional[bool] = False,
        forecast_minutes_ignore: Optional[int] = 0,
    ):
        """Build the per-modality encoders, the fusion module and the output head.

        Args:
            output_network: Factory called with `in_features` and `out_features`
                to build the output head.
            output_quantiles: Forwarded to the base-class constructor.
            nwp_encoders_dict: Mapping from NWP source name to an encoder
                factory. Each factory must expose `keywords["in_channels"]`
                and is called with `sequence_length` and `in_channels`. NWP is
                used only when this mapping is non-empty.
            sat_encoder: Satellite encoder factory with the same contract as
                the NWP factories. `None` disables the satellite modality.
            pv_encoder: Factory called with `sequence_length`,
                `target_key_to_use` and `input_key_to_use="pv"`. `None`
                disables the PV modality.
            wind_encoder: As `pv_encoder`, with `input_key_to_use="wind"`.
            sensor_encoder: As `pv_encoder`, with `input_key_to_use="sensor"`.
            add_image_embedding_channel: If true, an `ImageEmbedding` indexed by
                the target id is applied to satellite and NWP inputs and the
                encoders are built with one extra input channel.
            include_gsp_yield_history: If true, the first `history_len` steps of
                `BatchKey.gsp` are used as a modality.
            include_sun: If true, solar azimuth and elevation are passed
                through a linear layer producing 16 features.
            include_time: If true, date/time sin/cos features are passed
                through a linear layer producing 32 features.
            embedding_dim: Width of the learned target-id embedding. A falsy
                value (`None` or `0`) disables it.
            fusion_hidden_dim: Hidden size passed to `DynamicFusionModule` and
                `CrossModalAttention`; also the input width of the output head.
            num_fusion_heads: Head count passed to both fusion components.
            fusion_dropout: Dropout passed to both fusion components.
            use_cross_attention: If true, a `CrossModalAttention` block is
                created and applied before fusion.
            fusion_method: Passed through to `DynamicFusionModule`.
            forecast_minutes: Forwarded to the base class; also the fallback
                for `sensor_forecast_minutes`.
            history_minutes: Forwarded to the base class; also the fallback for
                `wind_history_minutes` and `sensor_history_minutes`.
            sat_history_minutes: Required when `sat_encoder` is given.
            min_sat_delay_minutes: Subtracted from `sat_history_minutes` when
                computing the satellite sequence length.
            nwp_forecast_minutes: Per-source forecast length. Required, with
                the same keys as `nwp_encoders_dict`, when NWP is used.
            nwp_history_minutes: Per-source history length. Required, with the
                same keys as `nwp_encoders_dict`, when NWP is used.
            pv_history_minutes: Required when `pv_encoder` is given.
            wind_history_minutes: Defaults to `history_minutes` when `None`.
            sensor_history_minutes: Defaults to `history_minutes` when `None`.
            sensor_forecast_minutes: Defaults to `forecast_minutes` when `None`.
            optimizer: Forwarded to the base-class constructor.
            target_key: Forwarded to the base-class constructor.
            interval_minutes: Forwarded to the base-class constructor.
            nwp_interval_minutes: Per-source interval. Defaults to 60 for every
                source when `None`.
            pv_interval_minutes: Interval used for the PV sequence length.
            sat_interval_minutes: Interval used for the satellite sequence
                length.
            sensor_interval_minutes: Interval used for the sensor sequence
                length.
            wind_interval_minutes: Interval used for the wind sequence length.
            num_embeddings: Number of entries in the target-id embedding and in
                each `ImageEmbedding`.
            timestep_intervals_to_plot: Forwarded to the base-class constructor.
            adapt_batches: If true, `forward` first passes the batch through
                `self._adapt_batch`.
            use_weighted_loss: Forwarded to the base-class constructor.
            forecast_minutes_ignore: Forwarded to the base-class constructor.

        Raises:
            AssertionError: If a modality is enabled without its required
                history/forecast configuration, or if the NWP configuration
                keys do not match `nwp_encoders_dict`.
        """
        # These plain attributes are assigned before super().__init__(), as in
        # the original. NOTE(review): presumably the base constructor reads
        # some of them - confirm before reordering.
        self.include_gsp_yield_history = include_gsp_yield_history
        self.include_sat = sat_encoder is not None
        self.include_nwp = nwp_encoders_dict is not None and len(nwp_encoders_dict) != 0
        self.include_pv = pv_encoder is not None
        self.include_sun = include_sun
        self.include_time = include_time
        self.include_wind = wind_encoder is not None
        self.include_sensor = sensor_encoder is not None
        self.embedding_dim = embedding_dim
        self.add_image_embedding_channel = add_image_embedding_channel
        self.interval_minutes = interval_minutes
        self.min_sat_delay_minutes = min_sat_delay_minutes
        self.adapt_batches = adapt_batches

        super().__init__(
            history_minutes=history_minutes,
            forecast_minutes=forecast_minutes,
            optimizer=optimizer,
            output_quantiles=output_quantiles,
            target_key=target_key,
            interval_minutes=interval_minutes,
            timestep_intervals_to_plot=timestep_intervals_to_plot,
            use_weighted_loss=use_weighted_loss,
            forecast_minutes_ignore=forecast_minutes_ignore,
        )

        # Modality name -> width of its encoded feature vector, in
        # registration order. Handed to the fusion module below.
        feature_dims = {}

        if self.include_sat:
            assert (
                sat_history_minutes is not None
            ), "sat_history_minutes is required when sat_encoder is given"

            self.sat_sequence_len = (
                sat_history_minutes - min_sat_delay_minutes
            ) // sat_interval_minutes + 1

            # The bool flag adds 0 or 1 to the configured channel count.
            self.sat_encoder = sat_encoder(
                sequence_length=self.sat_sequence_len,
                in_channels=sat_encoder.keywords["in_channels"] + add_image_embedding_channel,
            )
            if add_image_embedding_channel:
                self.sat_embed = ImageEmbedding(
                    num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels
                )

            feature_dims["sat"] = self.sat_encoder.out_features

        if self.include_nwp:
            assert (
                nwp_forecast_minutes is not None
            ), "nwp_forecast_minutes is required when NWP encoders are given"
            assert (
                nwp_history_minutes is not None
            ), "nwp_history_minutes is required when NWP encoders are given"

            nwp_sources = set(nwp_encoders_dict.keys())
            assert nwp_sources == set(
                nwp_forecast_minutes.keys()
            ), "nwp_forecast_minutes keys must match nwp_encoders_dict keys"
            assert nwp_sources == set(
                nwp_history_minutes.keys()
            ), "nwp_history_minutes keys must match nwp_encoders_dict keys"

            if nwp_interval_minutes is None:
                nwp_interval_minutes = dict.fromkeys(nwp_encoders_dict.keys(), 60)

            self.nwp_encoders_dict = torch.nn.ModuleDict()
            if add_image_embedding_channel:
                self.nwp_embed_dict = torch.nn.ModuleDict()

            for nwp_source in nwp_encoders_dict.keys():
                source_interval = nwp_interval_minutes[nwp_source]
                encoder_factory = nwp_encoders_dict[nwp_source]

                nwp_sequence_len = (
                    nwp_history_minutes[nwp_source] // source_interval
                    + nwp_forecast_minutes[nwp_source] // source_interval
                    + 1
                )

                self.nwp_encoders_dict[nwp_source] = encoder_factory(
                    sequence_length=nwp_sequence_len,
                    in_channels=(
                        encoder_factory.keywords["in_channels"] + add_image_embedding_channel
                    ),
                )
                if add_image_embedding_channel:
                    self.nwp_embed_dict[nwp_source] = ImageEmbedding(
                        num_embeddings,
                        nwp_sequence_len,
                        self.nwp_encoders_dict[nwp_source].image_size_pixels,
                    )

                feature_dims[f"nwp/{nwp_source}"] = self.nwp_encoders_dict[nwp_source].out_features

        if self.include_pv:
            assert (
                pv_history_minutes is not None
            ), "pv_history_minutes is required when pv_encoder is given"

            self.pv_encoder = pv_encoder(
                sequence_length=pv_history_minutes // pv_interval_minutes + 1,
                target_key_to_use=self._target_key_name,
                input_key_to_use="pv",
            )

            feature_dims["pv"] = self.pv_encoder.out_features

        if self.include_wind:
            if wind_history_minutes is None:
                wind_history_minutes = history_minutes

            self.wind_encoder = wind_encoder(
                sequence_length=wind_history_minutes // wind_interval_minutes + 1,
                target_key_to_use=self._target_key_name,
                input_key_to_use="wind",
            )

            feature_dims["wind"] = self.wind_encoder.out_features

        if self.include_sensor:
            if sensor_history_minutes is None:
                sensor_history_minutes = history_minutes
            if sensor_forecast_minutes is None:
                sensor_forecast_minutes = forecast_minutes

            self.sensor_encoder = sensor_encoder(
                sequence_length=sensor_history_minutes // sensor_interval_minutes
                + sensor_forecast_minutes // sensor_interval_minutes
                + 1,
                target_key_to_use=self._target_key_name,
                input_key_to_use="sensor",
            )

            feature_dims["sensor"] = self.sensor_encoder.out_features

        if self.embedding_dim:
            self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
            feature_dims["embedding"] = embedding_dim

        # Number of timesteps covered by the sun and time inputs.
        num_timesteps = self.forecast_len + self.forecast_len_ignore + self.history_len + 1

        if self.include_sun:
            # Two sun features per timestep: azimuth and elevation.
            self.sun_fc1 = nn.Linear(in_features=2 * num_timesteps, out_features=16)
            feature_dims["sun"] = 16

        if self.include_time:
            # Four time features per timestep: date sin/cos and time sin/cos.
            self.time_fc1 = nn.Linear(in_features=4 * num_timesteps, out_features=32)
            feature_dims["time"] = 32

        if include_gsp_yield_history:
            feature_dims["gsp"] = self.history_len

        self.fusion_module = DynamicFusionModule(
            feature_dims=feature_dims,
            hidden_dim=fusion_hidden_dim,
            num_heads=num_fusion_heads,
            dropout=fusion_dropout,
            fusion_method=fusion_method,
            use_residual=True,
        )

        if use_cross_attention:
            self.cross_attention = CrossModalAttention(
                embed_dim=fusion_hidden_dim,
                num_heads=num_fusion_heads,
                dropout=fusion_dropout,
                num_modalities=len(feature_dims),
            )
        else:
            self.cross_attention = None

        self.output_network = output_network(
            in_features=fusion_hidden_dim,
            out_features=self.num_output_features,
        )

        self.save_hyperparameters()

    def _get_target_id(self, x):
        """Return the integer target id of each sample, read from column 0 of the id entry."""
        return x[BatchKey[f"{self._target_key_name}_id"]][:, 0].int()

    def _encode_site_modality(self, x, modality, batch_key, encoder):
        """Encode a site-level modality, hiding future values when it is the target.

        Args:
            x: The input batch.
            modality: Modality name ("pv", "wind" or "sensor").
            batch_key: Key in `x` holding the data for this modality.
            encoder: The encoder module, called with the whole batch.

        Returns:
            The encoder output for this modality.
        """
        if self._target_key_name != modality:
            return encoder(x)

        # The modality is also the target: pass the encoder a shallow copy whose
        # series is cut to the first history_len + 1 steps, leaving `x` intact.
        x_tmp = x.copy()
        x_tmp[batch_key] = x_tmp[batch_key][:, : self.history_len + 1]
        return encoder(x_tmp)

    def forward(self, x: dict) -> torch.Tensor:
        """Encode every enabled modality, fuse the encodings and predict.

        Args:
            x: Input batch, indexed by `BatchKey` members (and by plain string
                keys for the time features).

        Returns:
            The output-network prediction. With quantile regression it is
            reshaped to `(batch, forecast_len, len(output_quantiles))`.
        """
        if self.adapt_batches:
            x = self._adapt_batch(x)

        # The target id feeds the image embeddings and the id embedding, so
        # look it up once, and only when at least one of them is active.
        uses_image_embedding = self.add_image_embedding_channel and (
            self.include_sat or self.include_nwp
        )
        target_id = None
        if uses_image_embedding or self.embedding_dim:
            target_id = self._get_target_id(x)

        # NOTE(review): insertion order here differs from the registration
        # order of `feature_dims` in __init__ (gsp/embedding come before
        # wind/sensor). Kept as in the original; confirm whether the fusion
        # modules depend on ordering.
        encoded_features = OrderedDict()

        if self.include_sat:
            sat_data = x[BatchKey.satellite_actual][:, : self.sat_sequence_len]
            # Swap axes 1 and 2 before encoding.
            sat_data = torch.swapaxes(sat_data, 1, 2).float()

            if self.add_image_embedding_channel:
                sat_data = self.sat_embed(sat_data, target_id)
            encoded_features["sat"] = self.sat_encoder(sat_data)

        if self.include_nwp:
            for nwp_source in self.nwp_encoders_dict:
                nwp_data = x[BatchKey.nwp][nwp_source][NWPBatchKey.nwp].float()
                nwp_data = torch.swapaxes(nwp_data, 1, 2)
                # Clamp values to [-50, 50] before encoding.
                nwp_data = torch.clip(nwp_data, min=-50, max=50)

                if self.add_image_embedding_channel:
                    nwp_data = self.nwp_embed_dict[nwp_source](nwp_data, target_id)

                encoded_features[f"nwp/{nwp_source}"] = self.nwp_encoders_dict[nwp_source](nwp_data)

        if self.include_pv:
            encoded_features["pv"] = self._encode_site_modality(
                x, "pv", BatchKey.pv, self.pv_encoder
            )

        if self.include_gsp_yield_history:
            gsp_history = x[BatchKey.gsp][:, : self.history_len].float()
            # Flatten everything after the batch dimension.
            encoded_features["gsp"] = gsp_history.reshape(gsp_history.shape[0], -1)

        if self.embedding_dim:
            encoded_features["embedding"] = self.embed(target_id)

        if self.include_wind:
            encoded_features["wind"] = self._encode_site_modality(
                x, "wind", BatchKey.wind, self.wind_encoder
            )

        if self.include_sensor:
            encoded_features["sensor"] = self._encode_site_modality(
                x, "sensor", BatchKey.sensor, self.sensor_encoder
            )

        if self.include_sun:
            sun = torch.cat(
                (
                    x[BatchKey[f"{self._target_key_name}_solar_azimuth"]],
                    x[BatchKey[f"{self._target_key_name}_solar_elevation"]],
                ),
                dim=1,
            ).float()
            encoded_features["sun"] = self.sun_fc1(sun)

        if self.include_time:
            # NOTE(review): these lookups use plain string keys, unlike the
            # BatchKey lookups everywhere else - confirm the batch carries them.
            time = torch.cat(
                (
                    x[f"{self._target_key_name}_date_sin"],
                    x[f"{self._target_key_name}_date_cos"],
                    x[f"{self._target_key_name}_time_sin"],
                    x[f"{self._target_key_name}_time_cos"],
                ),
                dim=1,
            ).float()
            encoded_features["time"] = self.time_fc1(time)

        # Cross-modal attention is skipped when there is only one modality.
        if self.cross_attention is not None and len(encoded_features) > 1:
            encoded_features = self.cross_attention(encoded_features)

        fused_features = self.fusion_module(encoded_features)
        out = self.output_network(fused_features)

        if self.use_quantile_regression:
            out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))

        return out

0 commit comments

Comments
 (0)