# dynamic_encoder.py

"""Dynamic fusion encoder implementation for multimodal learning."""

from typing import Dict, Optional, List, Union

import torch
from torch import nn

from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating
from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention
from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2


class PVEncoder(nn.Module):
    """Simplified PV encoder - maintains the sequence dimension."""

    def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        super().__init__()
        self.sequence_length = sequence_length
        self.num_sites = num_sites
        self.out_features = out_features

        # Process each timestep independently
        self.encoder = nn.Sequential(
            nn.Linear(num_sites, out_features),
            nn.LayerNorm(out_features),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x):
        # x: [batch_size, sequence_length, num_sites]
        # Encode each timestep independently, then stack along the sequence dimension
        out = [self.encoder(x[:, t]) for t in range(self.sequence_length)]
        return torch.stack(out, dim=1)  # [batch_size, sequence_length, out_features]
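

# Usage sketch for PVEncoder (illustrative only); shapes follow the comments in
# forward() above:
#
#     enc = PVEncoder(sequence_length=4, num_sites=10, out_features=256)
#     y = enc(torch.randn(2, 4, 10))  # -> torch.Size([2, 4, 256])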


class DynamicFusionEncoder(AbstractNWPSatelliteEncoder):
    """Encoder that implements dynamic fusion of satellite/NWP data streams."""

    def __init__(
        self,
        sequence_length: int,
        image_size_pixels: int,
        modality_channels: Dict[str, int],
        out_features: int,
        modality_encoders: Dict[str, dict],
        cross_attention: Dict,
        modality_gating: Dict,
        dynamic_fusion: Dict,
        hidden_dim: int = 256,
        fc_features: int = 128,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_gating: bool = True,
        use_cross_attention: bool = True,
    ):
        """Dynamic fusion encoder for multimodal satellite/NWP data."""
        super().__init__(
            sequence_length=sequence_length,
            image_size_pixels=image_size_pixels,
            in_channels=sum(modality_channels.values()),
            out_features=out_features,
        )

        self.modalities = list(modality_channels.keys())
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length

        # Initialize modality-specific encoders
        self.modality_encoders = nn.ModuleDict()
        for modality, config in modality_encoders.items():
            config = config.copy()
            if 'nwp' in modality or 'sat' in modality:
                # The 3D conv encoder emits one flat feature vector per sample;
                # size it as sequence_length * hidden_dim so it can be reshaped
                # to one hidden_dim vector per timestep below
                encoder = DefaultPVNet2(
                    sequence_length=sequence_length,
                    image_size_pixels=config.get('image_size_pixels', image_size_pixels),
                    in_channels=modality_channels[modality],
                    out_features=sequence_length * hidden_dim,
                    number_of_conv3d_layers=config.get('number_of_conv3d_layers', 4),
                    conv3d_channels=config.get('conv3d_channels', 32),
                    batch_norm=config.get('batch_norm', True),
                    fc_dropout=config.get('fc_dropout', 0.2),
                )

                # [batch, sequence_length * hidden_dim] -> [batch, sequence_length, hidden_dim]
                self.modality_encoders[modality] = nn.Sequential(
                    encoder,
                    nn.Unflatten(1, (sequence_length, hidden_dim)),
                )

            elif modality == 'pv':
                self.modality_encoders[modality] = PVEncoder(
                    sequence_length=sequence_length,
                    num_sites=config['num_sites'],
                    out_features=hidden_dim,
                )

        # Per-modality feature projections (applied after encoding in forward)
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            for modality in modality_channels.keys()
        })

        # Optional modality gating
        self.use_gating = use_gating
        if use_gating:
            gating_config = modality_gating.copy()
            gating_config['feature_dims'] = {
                mod: hidden_dim for mod in modality_channels.keys()
            }
            self.gating = ModalityGating(**gating_config)

        # Optional cross-modal attention
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            attention_config = cross_attention.copy()
            attention_config['embed_dim'] = hidden_dim
            self.cross_attention = CrossModalAttention(**attention_config)

        # Dynamic fusion module
        fusion_config = dynamic_fusion.copy()
        fusion_config['feature_dims'] = {
            mod: hidden_dim for mod in modality_channels.keys()
        }
        fusion_config['hidden_dim'] = hidden_dim
        self.fusion_module = DynamicFusionModule(**fusion_config)

        # Final output projection over the flattened sequence of fused features
        self.final_block = nn.Sequential(
            nn.Linear(hidden_dim * sequence_length, fc_features),
            nn.ELU(),
            nn.Linear(fc_features, out_features),
            nn.ELU(),
        )

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass of the dynamic fusion encoder."""
        # Initial encoding of each modality
        encoded_features = {}
        for modality, x in inputs.items():
            if modality not in self.modality_encoders:
                continue

            # Apply the modality-specific encoder and its feature projection
            # Output shape: [batch_size, sequence_length, hidden_dim]
            features = self.modality_encoders[modality](x)
            encoded_features[modality] = self.feature_projections[modality](features)

        if not encoded_features:
            raise ValueError("No valid features found in inputs")

        # Apply modality gating if enabled
        if self.use_gating:
            encoded_features = self.gating(encoded_features)

        # Apply cross-modal attention if enabled and more than one modality is present
        if self.use_cross_attention and len(encoded_features) > 1:
            encoded_features = self.cross_attention(encoded_features, mask)

        # Apply dynamic fusion
        fused_features = self.fusion_module(encoded_features, mask)  # [batch, sequence, hidden]

        # Flatten the sequence dimension and apply the final projection
        batch_size = fused_features.size(0)
        fused_features = fused_features.reshape(batch_size, -1)
        output = self.final_block(fused_features)

        return output


class ResidualProjection(nn.Module):
    """Two-layer feature projection wrapped in a residual (skip) connection."""

    def __init__(self, dim: int, dropout: float = 0.1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        )

    def forward(self, x):
        # Add the input back onto the projected features
        return x + self.block(x)


class DynamicResidualEncoder(DynamicFusionEncoder):
    """Dynamic fusion encoder with residual connections in its feature projections"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Override the parent's feature projections with residual blocks
        self.feature_projections = nn.ModuleDict({
            modality: ResidualProjection(self.hidden_dim, kwargs.get('dropout', 0.1))
            for modality in self.modalities
        })
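

# --- Usage sketch (illustrative only, not part of the module) -----------------
# A minimal smoke test showing how DynamicFusionEncoder might be constructed.
# Assumptions (not guaranteed by this file): ModalityGating, CrossModalAttention
# and DynamicFusionModule can be built from empty config dicts plus the keys the
# encoder injects (feature_dims, embed_dim, hidden_dim), and DefaultPVNet2 takes
# satellite input as [batch, channels, time, height, width].
if __name__ == "__main__":
    seq_len, img_px, hidden = 4, 24, 256

    encoder = DynamicFusionEncoder(
        sequence_length=seq_len,
        image_size_pixels=img_px,
        modality_channels={"sat": 11, "pv": 1},
        out_features=128,
        modality_encoders={"sat": {}, "pv": {"num_sites": 10}},
        cross_attention={},   # embed_dim is injected by the encoder
        modality_gating={},   # feature_dims is injected by the encoder
        dynamic_fusion={},    # feature_dims / hidden_dim are injected by the encoder
        hidden_dim=hidden,
    )

    batch = {
        "sat": torch.randn(2, 11, seq_len, img_px, img_px),  # assumed layout
        "pv": torch.randn(2, seq_len, 10),                   # [batch, time, n_sites]
    }
    print(encoder(batch).shape)  # expected: torch.Size([2, 128])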