Skip to content

Commit bef8b48

Browse files
committed
Make image embedding size configurable
1 parent f4a9a0d commit bef8b48

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pvnet/models/multimodal/multimodal.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
pv_interval_minutes: int = 5,
6464
sat_interval_minutes: int = 5,
6565
sensor_interval_minutes: int = 30,
66+
image_embedding_dim: Optional[int] = 318,
6667
timestep_intervals_to_plot: Optional[list[int]] = None,
6768
):
6869
"""Neural network which combines information from different sources.
@@ -111,6 +112,7 @@ def __init__(
111112
pv_interval_minutes: The interval between each sample of the PV data
112113
sat_interval_minutes: The interval between each sample of the satellite data
113114
sensor_interval_minutes: The interval between each sample of the sensor data
115+
image_embedding_dim: The number of dimensions to use for the image embedding
114116
timestep_intervals_to_plot: Intervals, in timesteps, to plot in addition to the full forecast
115117
"""
116118

@@ -155,7 +157,7 @@ def __init__(
155157
)
156158
if add_image_embedding_channel:
157159
self.sat_embed = ImageEmbedding(
158-
318, self.sat_sequence_len, self.sat_encoder.image_size_pixels
160+
image_embedding_dim, self.sat_sequence_len, self.sat_encoder.image_size_pixels
159161
)
160162

161163
# Update num features
@@ -190,7 +192,7 @@ def __init__(
190192
)
191193
if add_image_embedding_channel:
192194
self.nwp_embed_dict[nwp_source] = ImageEmbedding(
193-
318, nwp_sequence_len, self.nwp_encoders_dict[nwp_source].image_size_pixels
195+
image_embedding_dim, nwp_sequence_len, self.nwp_encoders_dict[nwp_source].image_size_pixels
194196
)
195197

196198
# Update num features
@@ -229,7 +231,7 @@ def __init__(
229231
fusion_input_features += self.sensor_encoder.out_features
230232

231233
if self.embedding_dim:
232-
self.embed = nn.Embedding(num_embeddings=318, embedding_dim=embedding_dim)
234+
self.embed = nn.Embedding(num_embeddings=image_embedding_dim, embedding_dim=embedding_dim)
233235

234236
# Update num features
235237
fusion_input_features += embedding_dim

0 commit comments

Comments
 (0)