From bdd873264c00a90092c2202d4850aacb1f6f665a Mon Sep 17 00:00:00 2001 From: Sukh-P Date: Mon, 20 Jan 2025 11:21:46 +0000 Subject: [PATCH] Update site encoder --- pvnet/models/multimodal/multimodal.py | 8 ++-- .../multimodal/site_encoders/basic_blocks.py | 2 +- .../multimodal/site_encoders/encoders.py | 43 +++++++++---------- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index b6e0e8e2..3482ba0c 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -13,7 +13,7 @@ from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork from pvnet.models.multimodal.multimodal_base import MultimodalBaseModel -from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder +from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractSitesEncoder from pvnet.optimizers import AbstractOptimizer @@ -42,8 +42,8 @@ def __init__( output_quantiles: Optional[list[float]] = None, nwp_encoders_dict: Optional[dict[AbstractNWPSatelliteEncoder]] = None, sat_encoder: Optional[AbstractNWPSatelliteEncoder] = None, - site_encoder: Optional[AbstractPVSitesEncoder] = None, - sensor_encoder: Optional[AbstractPVSitesEncoder] = None, + site_encoder: Optional[AbstractSitesEncoder] = None, + sensor_encoder: Optional[AbstractSitesEncoder] = None, add_image_embedding_channel: bool = False, include_gsp_yield_history: bool = True, include_sun: bool = True, @@ -222,7 +222,7 @@ def __init__( assert site_history_minutes is not None self.site_encoder = site_encoder( - sequence_length=site_history_minutes // site_interval_minutes + 1, + sequence_length=site_history_minutes // site_interval_minutes - 1, target_key_to_use=self._target_key_name, input_key_to_use="site", ) diff --git a/pvnet/models/multimodal/site_encoders/basic_blocks.py b/pvnet/models/multimodal/site_encoders/basic_blocks.py index b20835f1..525ba74a 100644 --- a/pvnet/models/multimodal/site_encoders/basic_blocks.py +++ b/pvnet/models/multimodal/site_encoders/basic_blocks.py @@ -4,7 +4,7 @@ from torch import nn -class AbstractPVSitesEncoder(nn.Module, metaclass=ABCMeta): +class AbstractSitesEncoder(nn.Module, metaclass=ABCMeta): """Abstract class for encoder for output data from multiple PV sites. The encoder will take an input of shape (batch_size, sequence_length, num_sites) diff --git a/pvnet/models/multimodal/site_encoders/encoders.py b/pvnet/models/multimodal/site_encoders/encoders.py index ecd05709..67b5c990 100644 --- a/pvnet/models/multimodal/site_encoders/encoders.py +++ b/pvnet/models/multimodal/site_encoders/encoders.py @@ -8,10 +8,10 @@ from torch import nn from pvnet.models.multimodal.linear_networks.networks import ResFCNet2 -from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder +from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractSitesEncoder -class SimpleLearnedAggregator(AbstractPVSitesEncoder): +class SimpleLearnedAggregator(AbstractSitesEncoder): """A simple model which learns a different weighted-average across all PV sites for each GSP. Each sequence from each site is independently encodeded through some dense layers wih skip- @@ -107,7 +107,7 @@ def forward(self, x): return x_out -class SingleAttentionNetwork(AbstractPVSitesEncoder): +class SingleAttentionNetwork(AbstractSitesEncoder): """A simple attention-based model with a single multihead attention layer For the attention layer the query is based on the target alone, the key is based on the @@ -127,8 +127,8 @@ def __init__( kv_res_block_layers: int = 2, use_id_in_value: bool = False, target_id_dim: int = 318, - target_key_to_use: str = "gsp", - input_key_to_use: str = "pv", + target_key_to_use: str = "site", + input_key_to_use: str = "site", num_channels: int = 1, num_sites_in_inference: int = 1, ): @@ -136,17 +136,17 @@ def __init__( Args: sequence_length: The time sequence length of the data. - num_sites: Number of PV sites in the input data. + num_sites: Number of sites in the input data. out_features: Number of output features. In this network this is also the embed and value dimension in the multi-head attention layer. kdim: The dimensions used the keys. - id_embed_dim: Number of dimensiosn used in the wind ID embedding layer(s). + id_embed_dim: Number of dimensiosn used in the site ID embedding layer(s). num_heads: Number of parallel attention heads. Note that `out_features` will be split across `num_heads` so `out_features` must be a multiple of `num_heads`. n_kv_res_blocks: Number of residual blocks to use in the key and value encoders. kv_res_block_layers: Number of fully-connected layers used in each residual block within the key and value encoders. - use_id_in_value: Whether to use a PV ID embedding in network used to produce the + use_id_in_value: Whether to use a site ID embedding in network used to produce the value for the attention layer. target_id_dim: The number of unique IDs. target_key_to_use: The key to use for the target in the attention layer. @@ -206,9 +206,10 @@ def __init__( ) def _encode_inputs(self, x): - # Shape: [batch size, sequence length, PV site] -> [8, 197, 1] + # Shape: [batch size, sequence length, number of sites] -> [8, 197, 1] # Shape: [batch size, station_id, sequence length, channels] -> [8, 197, 26, 23] - input_data = x[BatchKey[f"{self.input_key_to_use}"]] + input_data = x[f"{self.input_key_to_use}"] + input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D if len(input_data.shape) == 4: # Has multiple channels input_data = input_data[:, :, : self.sequence_length] input_data = einops.rearrange(input_data, "b id s c -> b (s c) id") @@ -223,25 +224,23 @@ def _encode_query(self, x): # Select the first one if self.target_key_to_use == "gsp": # GSP seems to have a different structure - ids = x[BatchKey[f"{self.target_key_to_use}_id"]] + ids = x[f"{self.target_key_to_use}_id"] else: - ids = x[BatchKey[f"{self.input_key_to_use}_id"]][:, 0] - ids = ids.squeeze().int() - if len(ids.shape) == 0: # Batch was squeezed down to nothing - ids = ids.unsqueeze(0) + ids = x[f"{self.input_key_to_use}_id"] + ids = ids.int() query = self.target_id_embedding(ids).unsqueeze(1) return query def _encode_key(self, x): site_seqs, batch_size = self._encode_inputs(x) - # wind ID embeddings are the same for each sample + # site ID embeddings are the same for each sample site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1)) - # Each concated (wind sequence, wind ID embedding) is processed with encoder + # Each concated (site sequence, site ID embedding) is processed with encoder x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1) key = self._key_encoder(x_seq_in) - # Reshape to [batch size, PV site, kdim] + # Reshape to [batch size, site, kdim] key = key.unflatten(0, (batch_size, self.num_sites)) return key @@ -249,16 +248,16 @@ def _encode_value(self, x): site_seqs, batch_size = self._encode_inputs(x) if self.use_id_in_value: - # wind ID embeddings are the same for each sample + # site ID embeddings are the same for each sample site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1)) - # Each concated (wind sequence, wind ID embedding) is processed with encoder + # Each concated (site sequence, site ID embedding) is processed with encoder x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1) else: - # Encode each PV sequence independently + # Encode each site sequence independently x_seq_in = site_seqs.flatten(0, 1) value = self._value_encoder(x_seq_in) - # Reshape to [batch size, PV site, vdim] + # Reshape to [batch size, site, vdim] value = value.unflatten(0, (batch_size, self.num_sites)) return value