 
 """
 
+import einops
 import torch
 from ocf_datapipes.batch import BatchKey
 from torch import nn
@@ -128,6 +129,8 @@ def __init__(
         target_id_dim: int = 318,
         target_key_to_use: str = "gsp",
         input_key_to_use: str = "pv",
+        num_channels: int = 1,
+        num_sites_in_inference: int = 1,
     ):
         """A simple attention-based model with a single multihead attention layer
 
@@ -148,6 +151,13 @@ def __init__(
             target_id_dim: The number of unique IDs.
             target_key_to_use: The key to use for the target in the attention layer.
             input_key_to_use: The key to use for the input in the attention layer.
+            num_channels: Number of channels in the input data. For single-site
+                generation this will be 1, as there is no channel dimension; for
+                sensor data it will generally be higher.
+            num_sites_in_inference: Number of sites to use in inference. This
+                determines how many sites are passed through the attention layer:
+                1 for a single site, and more for multiple sites (such as multiple
+                sensors).
 
         """
         super().__init__(sequence_length, num_sites, out_features)
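The two new arguments feed directly into the encoder sizes changed in the hunks below. As a quick illustration of the sizing logic only (the numeric values here are hypothetical, not taken from the source), the flattened per-site input width grows with `num_channels`:

```python
# Hypothetical example values, purely to illustrate the in_features sizing below.
sequence_length = 26
id_embed_dim = 10

# Single-site generation: no channel dimension, so num_channels = 1.
print(sequence_length * 1 + id_embed_dim)   # 36 input features per site

# Multi-channel sensor data, e.g. 23 channels per station.
print(sequence_length * 23 + id_embed_dim)  # 608 input features per site
```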
@@ -158,15 +168,18 @@ def __init__(
         self.use_id_in_value = use_id_in_value
         self.target_key_to_use = target_key_to_use
         self.input_key_to_use = input_key_to_use
+        self.num_channels = num_channels
+        self.num_sites_in_inference = num_sites_in_inference
 
         if use_id_in_value:
             self.value_id_embedding = nn.Embedding(num_sites, id_embed_dim)
 
         self._value_encoder = nn.Sequential(
             ResFCNet2(
-                in_features=sequence_length + int(use_id_in_value) * id_embed_dim,
+                in_features=sequence_length * self.num_channels
+                + int(use_id_in_value) * id_embed_dim,
                 out_features=out_features,
-                fc_hidden_features=sequence_length,
+                fc_hidden_features=sequence_length * self.num_channels,
                 n_res_blocks=n_kv_res_blocks,
                 res_block_layers=kv_res_block_layers,
                 dropout_frac=0,
@@ -175,9 +188,9 @@ def __init__(
 
         self._key_encoder = nn.Sequential(
             ResFCNet2(
-                in_features=sequence_length + id_embed_dim,
+                in_features=id_embed_dim + sequence_length * self.num_channels,
                 out_features=kdim,
-                fc_hidden_features=id_embed_dim + sequence_length,
+                fc_hidden_features=id_embed_dim + sequence_length * self.num_channels,
                 n_res_blocks=n_kv_res_blocks,
                 res_block_layers=kv_res_block_layers,
                 dropout_frac=0,
@@ -192,6 +205,20 @@ def __init__(
             batch_first=True,
         )
 
+    def _encode_inputs(self, x):
+        # Shape: [batch size, sequence length, PV site] -> [8, 197, 1]
+        # Shape: [batch size, station_id, sequence length, channels] -> [8, 197, 26, 23]
+        input_data = x[BatchKey[f"{self.input_key_to_use}"]]
+        if len(input_data.shape) == 4:  # Has multiple channels
+            input_data = input_data[:, :, : self.sequence_length]
+            input_data = einops.rearrange(input_data, "b id s c -> b (s c) id")
+        else:
+            input_data = input_data[:, : self.sequence_length]
+        site_seqs = input_data.float()
+        batch_size = site_seqs.shape[0]
+        site_seqs = site_seqs.swapaxes(1, 2)  # [batch size, site ID, sequence length]
+        return site_seqs, batch_size
+
     def _encode_query(self, x):
         # Select the first one
         if self.target_key_to_use == "gsp":
@@ -206,34 +233,29 @@ def _encode_query(self, x):
         return query
 
     def _encode_key(self, x):
-        # Shape: [batch size, sequence length, PV site]
-        site_seqs = x[BatchKey[f"{self.input_key_to_use}"]][:, : self.sequence_length].float()
-        batch_size = site_seqs.shape[0]
+        site_seqs, batch_size = self._encode_inputs(x)
 
         # wind ID embeddings are the same for each sample
         site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
         # Each concatenated (wind sequence, wind ID embedding) is processed with encoder
-        x_seq_in = torch.cat((site_seqs.swapaxes(1, 2), site_id_embed), dim=2).flatten(0, 1)
+        x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
         key = self._key_encoder(x_seq_in)
 
         # Reshape to [batch size, PV site, kdim]
         key = key.unflatten(0, (batch_size, self.num_sites))
         return key
 
     def _encode_value(self, x):
-        # Shape: [batch size, sequence length, PV site]
-        site_seqs = x[BatchKey[f"{self.input_key_to_use}"]][:, : self.sequence_length].float()
-        batch_size = site_seqs.shape[0]
+        site_seqs, batch_size = self._encode_inputs(x)
 
         if self.use_id_in_value:
             # wind ID embeddings are the same for each sample
             site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
             # Each concatenated (wind sequence, wind ID embedding) is processed with encoder
-            x_seq_in = torch.cat((site_seqs.swapaxes(1, 2), site_id_embed), dim=2).flatten(0, 1)
+            x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
         else:
             # Encode each PV sequence independently
-            x_seq_in = site_seqs.swapaxes(1, 2).flatten(0, 1)
-
+            x_seq_in = site_seqs.flatten(0, 1)
         value = self._value_encoder(x_seq_in)
 
         # Reshape to [batch size, PV site, vdim]
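For reference, a standalone sketch of the reshaping the new `_encode_inputs` method performs on multi-channel sensor data. The shapes are illustrative (not taken from the source), and it assumes `torch` and `einops` are installed:

```python
import einops
import torch

batch_size, num_sites, timesteps, num_channels = 8, 197, 30, 23
sequence_length = 26  # the model keeps only the first `sequence_length` steps

# Multi-channel sensor input: [batch size, station_id, sequence length, channels]
sensor_data = torch.randn(batch_size, num_sites, timesteps, num_channels)

# Crop to the model's sequence length, then fold (sequence, channel) into one axis per site.
cropped = sensor_data[:, :, :sequence_length]
flattened = einops.rearrange(cropped, "b id s c -> b (s c) id")

# Same swap as in `_encode_inputs`: [batch size, site ID, sequence length * channels]
site_seqs = flattened.float().swapaxes(1, 2)
assert site_seqs.shape == (batch_size, num_sites, sequence_length * num_channels)
```

Each site's flattened `sequence_length * num_channels` vector is then what `_encode_key` and `_encode_value` concatenate with the site ID embedding before encoding.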