2
2
3
3
"""
4
4
5
+ import einops
5
6
import torch
6
7
from ocf_datapipes .batch import BatchKey
7
8
from torch import nn
@@ -128,6 +129,8 @@ def __init__(
128
129
target_id_dim : int = 318 ,
129
130
target_key_to_use : str = "gsp" ,
130
131
input_key_to_use : str = "pv" ,
132
+ num_channels : int = 1 ,
133
+ num_sites_in_inference : int = 1 ,
131
134
):
132
135
"""A simple attention-based model with a single multihead attention layer
133
136
@@ -148,6 +151,13 @@ def __init__(
148
151
target_id_dim: The number of unique IDs.
149
152
target_key_to_use: The key to use for the target in the attention layer.
150
153
input_key_to_use: The key to use for the input in the attention layer.
154
+ num_channels: Number of channels in the input data. For single site generation,
155
+ this will be 1, as there is no channel dimension; for Sensors,
156
+ this will probably be higher than that
157
+ num_sites_in_inference: Number of sites to use in inference.
158
+ This is used to determine the number of sites to use in the
159
+ attention layer, for a single site, 1 works, while for multiple sites
160
+ (such as multiple sensors), this would be higher than that
151
161
152
162
"""
153
163
super ().__init__ (sequence_length , num_sites , out_features )
@@ -158,15 +168,18 @@ def __init__(
158
168
self .use_id_in_value = use_id_in_value
159
169
self .target_key_to_use = target_key_to_use
160
170
self .input_key_to_use = input_key_to_use
171
+ self .num_channels = num_channels
172
+ self .num_sites_in_inference = num_sites_in_inference
161
173
162
174
if use_id_in_value :
163
175
self .value_id_embedding = nn .Embedding (num_sites , id_embed_dim )
164
176
165
177
self ._value_encoder = nn .Sequential (
166
178
ResFCNet2 (
167
- in_features = sequence_length + int (use_id_in_value ) * id_embed_dim ,
179
+ in_features = sequence_length * self .num_channels
180
+ + int (use_id_in_value ) * id_embed_dim ,
168
181
out_features = out_features ,
169
- fc_hidden_features = sequence_length ,
182
+ fc_hidden_features = sequence_length * self . num_channels ,
170
183
n_res_blocks = n_kv_res_blocks ,
171
184
res_block_layers = kv_res_block_layers ,
172
185
dropout_frac = 0 ,
@@ -175,9 +188,9 @@ def __init__(
175
188
176
189
self ._key_encoder = nn .Sequential (
177
190
ResFCNet2 (
178
- in_features = sequence_length + id_embed_dim ,
191
+ in_features = id_embed_dim + sequence_length * self . num_channels ,
179
192
out_features = kdim ,
180
- fc_hidden_features = id_embed_dim + sequence_length ,
193
+ fc_hidden_features = id_embed_dim + sequence_length * self . num_channels ,
181
194
n_res_blocks = n_kv_res_blocks ,
182
195
res_block_layers = kv_res_block_layers ,
183
196
dropout_frac = 0 ,
@@ -192,6 +205,20 @@ def __init__(
192
205
batch_first = True ,
193
206
)
194
207
208
def _encode_inputs(self, x):
    """Extract the input sequences from the batch, arranged per site.

    Args:
        x: Batch mapping, indexed with `BatchKey[self.input_key_to_use]`.

    Returns:
        A `(site_seqs, batch_size)` tuple. `site_seqs` is the float input
        with axes 1 and 2 swapped so the site axis comes before the feature
        axis; `batch_size` is the size of its leading axis.
    """
    raw = x[BatchKey[f"{self.input_key_to_use}"]]

    if raw.ndim == 4:
        # Multi-channel input, laid out as [batch, site id, time, channel]
        # (e.g. [8, 197, 26, 23] per the original note -- TODO confirm).
        # Clip the time axis, then fold time and channel into one feature axis.
        clipped = raw[:, :, : self.sequence_length]
        features = einops.rearrange(clipped, "b id s c -> b (s c) id")
    else:
        # Presumably [batch, time, site] (e.g. [8, 197, 1]) -- TODO confirm.
        features = raw[:, : self.sequence_length]

    features = features.float()
    n_samples = features.shape[0]

    # Put the site axis second: [batch, site, features]
    return features.swapaxes(1, 2), n_samples
221
+
195
222
def _encode_query (self , x ):
196
223
# Select the first one
197
224
if self .target_key_to_use == "gsp" :
@@ -206,34 +233,29 @@ def _encode_query(self, x):
206
233
return query
207
234
208
235
def _encode_key(self, x):
    """Build the attention keys, one per site, for every sample in the batch.

    Args:
        x: Batch passed straight through to `self._encode_inputs`.

    Returns:
        The output of `self._key_encoder`, unflattened to
        [batch size, site, kdim].
    """
    per_site_seqs, n_samples = self._encode_inputs(x)

    # The ID embeddings are the same for each sample, so tile one copy over the batch
    id_embeddings = self.site_id_embedding(self._ids)
    tiled_ids = torch.tile(id_embeddings, (n_samples, 1, 1))

    # Append each site's ID embedding to its sequence, then merge the batch and
    # site axes so the encoder processes one row per (sample, site) pair
    joined = torch.cat((per_site_seqs, tiled_ids), dim=2)
    encoder_rows = joined.flatten(0, 1)
    encoded = self._key_encoder(encoder_rows)

    # Split the leading axis back into [batch size, site]
    return encoded.unflatten(0, (n_samples, self.num_sites))
222
247
223
248
def _encode_value (self , x ):
224
- # Shape: [batch size, sequence length, PV site]
225
- site_seqs = x [BatchKey [f"{ self .input_key_to_use } " ]][:, : self .sequence_length ].float ()
226
- batch_size = site_seqs .shape [0 ]
249
+ site_seqs , batch_size = self ._encode_inputs (x )
227
250
228
251
if self .use_id_in_value :
229
252
# wind ID embeddings are the same for each sample
230
253
site_id_embed = torch .tile (self .value_id_embedding (self ._ids ), (batch_size , 1 , 1 ))
231
254
# Each concatenated (wind sequence, wind ID embedding) is processed with encoder
232
- x_seq_in = torch .cat ((site_seqs . swapaxes ( 1 , 2 ) , site_id_embed ), dim = 2 ).flatten (0 , 1 )
255
+ x_seq_in = torch .cat ((site_seqs , site_id_embed ), dim = 2 ).flatten (0 , 1 )
233
256
else :
234
257
# Encode each PV sequence independently
235
- x_seq_in = site_seqs .swapaxes (1 , 2 ).flatten (0 , 1 )
236
-
258
+ x_seq_in = site_seqs .flatten (0 , 1 )
237
259
value = self ._value_encoder (x_seq_in )
238
260
239
261
# Reshape to [batch size, PV site, vdim]
0 commit comments