Skip to content

Commit

Permalink
Make image embedding size configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Feb 6, 2024
1 parent f4a9a0d commit bef8b48
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
pv_interval_minutes: int = 5,
sat_interval_minutes: int = 5,
sensor_interval_minutes: int = 30,
image_embedding_dim: Optional[int] = 318,
timestep_intervals_to_plot: Optional[list[int]] = None,
):
"""Neural network which combines information from different sources.
Expand Down Expand Up @@ -111,6 +112,7 @@ def __init__(
pv_interval_minutes: The interval between each sample of the PV data
sat_interval_minutes: The interval between each sample of the satellite data
sensor_interval_minutes: The interval between each sample of the sensor data
image_embedding_dim: The number of dimensions to use for the image embedding
timestep_intervals_to_plot: Intervals, in timesteps, to plot in addition to the full forecast
"""

Expand Down Expand Up @@ -155,7 +157,7 @@ def __init__(
)
if add_image_embedding_channel:
self.sat_embed = ImageEmbedding(
318, self.sat_sequence_len, self.sat_encoder.image_size_pixels
image_embedding_dim, self.sat_sequence_len, self.sat_encoder.image_size_pixels
)

# Update num features
Expand Down Expand Up @@ -190,7 +192,7 @@ def __init__(
)
if add_image_embedding_channel:
self.nwp_embed_dict[nwp_source] = ImageEmbedding(
318, nwp_sequence_len, self.nwp_encoders_dict[nwp_source].image_size_pixels
image_embedding_dim, nwp_sequence_len, self.nwp_encoders_dict[nwp_source].image_size_pixels
)

# Update num features
Expand Down Expand Up @@ -229,7 +231,7 @@ def __init__(
fusion_input_features += self.sensor_encoder.out_features

if self.embedding_dim:
self.embed = nn.Embedding(num_embeddings=318, embedding_dim=embedding_dim)
self.embed = nn.Embedding(num_embeddings=image_embedding_dim, embedding_dim=embedding_dim)

# Update num features
fusion_input_features += embedding_dim
Expand Down

0 comments on commit bef8b48

Please sign in to comment.