Skip to content

Commit

Permalink
Update site encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Sukh-P committed Jan 20, 2025
1 parent 19f2c1e commit bdd8732
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 27 deletions.
8 changes: 4 additions & 4 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder
from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
from pvnet.models.multimodal.multimodal_base import MultimodalBaseModel
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractSitesEncoder
from pvnet.optimizers import AbstractOptimizer


Expand Down Expand Up @@ -42,8 +42,8 @@ def __init__(
output_quantiles: Optional[list[float]] = None,
nwp_encoders_dict: Optional[dict[AbstractNWPSatelliteEncoder]] = None,
sat_encoder: Optional[AbstractNWPSatelliteEncoder] = None,
site_encoder: Optional[AbstractPVSitesEncoder] = None,
sensor_encoder: Optional[AbstractPVSitesEncoder] = None,
site_encoder: Optional[AbstractSitesEncoder] = None,
sensor_encoder: Optional[AbstractSitesEncoder] = None,
add_image_embedding_channel: bool = False,
include_gsp_yield_history: bool = True,
include_sun: bool = True,
Expand Down Expand Up @@ -222,7 +222,7 @@ def __init__(
assert site_history_minutes is not None

self.site_encoder = site_encoder(
sequence_length=site_history_minutes // site_interval_minutes + 1,
sequence_length=site_history_minutes // site_interval_minutes - 1,
target_key_to_use=self._target_key_name,
input_key_to_use="site",
)
Expand Down
2 changes: 1 addition & 1 deletion pvnet/models/multimodal/site_encoders/basic_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn


class AbstractPVSitesEncoder(nn.Module, metaclass=ABCMeta):
class AbstractSitesEncoder(nn.Module, metaclass=ABCMeta):
"""Abstract class for encoder for output data from multiple PV sites.
The encoder will take an input of shape (batch_size, sequence_length, num_sites)
Expand Down
43 changes: 21 additions & 22 deletions pvnet/models/multimodal/site_encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from torch import nn

from pvnet.models.multimodal.linear_networks.networks import ResFCNet2
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractPVSitesEncoder
from pvnet.models.multimodal.site_encoders.basic_blocks import AbstractSitesEncoder


class SimpleLearnedAggregator(AbstractPVSitesEncoder):
class SimpleLearnedAggregator(AbstractSitesEncoder):
"""A simple model which learns a different weighted-average across all PV sites for each GSP.
Each sequence from each site is independently encodeded through some dense layers wih skip-
Expand Down Expand Up @@ -107,7 +107,7 @@ def forward(self, x):
return x_out


class SingleAttentionNetwork(AbstractPVSitesEncoder):
class SingleAttentionNetwork(AbstractSitesEncoder):
"""A simple attention-based model with a single multihead attention layer
For the attention layer the query is based on the target alone, the key is based on the
Expand All @@ -127,26 +127,26 @@ def __init__(
kv_res_block_layers: int = 2,
use_id_in_value: bool = False,
target_id_dim: int = 318,
target_key_to_use: str = "gsp",
input_key_to_use: str = "pv",
target_key_to_use: str = "site",
input_key_to_use: str = "site",
num_channels: int = 1,
num_sites_in_inference: int = 1,
):
"""A simple attention-based model with a single multihead attention layer
Args:
sequence_length: The time sequence length of the data.
num_sites: Number of PV sites in the input data.
num_sites: Number of sites in the input data.
out_features: Number of output features. In this network this is also the embed and
value dimension in the multi-head attention layer.
kdim: The dimensions used the keys.
id_embed_dim: Number of dimensiosn used in the wind ID embedding layer(s).
id_embed_dim: Number of dimensiosn used in the site ID embedding layer(s).
num_heads: Number of parallel attention heads. Note that `out_features` will be split
across `num_heads` so `out_features` must be a multiple of `num_heads`.
n_kv_res_blocks: Number of residual blocks to use in the key and value encoders.
kv_res_block_layers: Number of fully-connected layers used in each residual block within
the key and value encoders.
use_id_in_value: Whether to use a PV ID embedding in network used to produce the
use_id_in_value: Whether to use a site ID embedding in network used to produce the
value for the attention layer.
target_id_dim: The number of unique IDs.
target_key_to_use: The key to use for the target in the attention layer.
Expand Down Expand Up @@ -206,9 +206,10 @@ def __init__(
)

def _encode_inputs(self, x):
# Shape: [batch size, sequence length, PV site] -> [8, 197, 1]
# Shape: [batch size, sequence length, number of sites] -> [8, 197, 1]
# Shape: [batch size, station_id, sequence length, channels] -> [8, 197, 26, 23]
input_data = x[BatchKey[f"{self.input_key_to_use}"]]
input_data = x[f"{self.input_key_to_use}"]
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
if len(input_data.shape) == 4: # Has multiple channels
input_data = input_data[:, :, : self.sequence_length]
input_data = einops.rearrange(input_data, "b id s c -> b (s c) id")
Expand All @@ -223,42 +224,40 @@ def _encode_query(self, x):
# Select the first one
if self.target_key_to_use == "gsp":
# GSP seems to have a different structure
ids = x[BatchKey[f"{self.target_key_to_use}_id"]]
ids = x[f"{self.target_key_to_use}_id"]
else:
ids = x[BatchKey[f"{self.input_key_to_use}_id"]][:, 0]
ids = ids.squeeze().int()
if len(ids.shape) == 0: # Batch was squeezed down to nothing
ids = ids.unsqueeze(0)
ids = x[f"{self.input_key_to_use}_id"]
ids = ids.int()
query = self.target_id_embedding(ids).unsqueeze(1)
return query

def _encode_key(self, x):
site_seqs, batch_size = self._encode_inputs(x)

# wind ID embeddings are the same for each sample
# site ID embeddings are the same for each sample
site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
# Each concated (wind sequence, wind ID embedding) is processed with encoder
# Each concated (site sequence, site ID embedding) is processed with encoder
x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
key = self._key_encoder(x_seq_in)

# Reshape to [batch size, PV site, kdim]
# Reshape to [batch size, site, kdim]
key = key.unflatten(0, (batch_size, self.num_sites))
return key

def _encode_value(self, x):
site_seqs, batch_size = self._encode_inputs(x)

if self.use_id_in_value:
# wind ID embeddings are the same for each sample
# site ID embeddings are the same for each sample
site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
# Each concated (wind sequence, wind ID embedding) is processed with encoder
# Each concated (site sequence, site ID embedding) is processed with encoder
x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
else:
# Encode each PV sequence independently
# Encode each site sequence independently
x_seq_in = site_seqs.flatten(0, 1)
value = self._value_encoder(x_seq_in)

# Reshape to [batch size, PV site, vdim]
# Reshape to [batch size, site, vdim]
value = value.unflatten(0, (batch_size, self.num_sites))
return value

Expand Down

0 comments on commit bdd8732

Please sign in to comment.