Skip to content

Commit ce1c672

Browse files
committed
Rename embedding dims
1 parent da912ad commit ce1c672

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pvnet/models/multimodal/multimodal.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
sat_interval_minutes: int = 5,
6565
sensor_interval_minutes: int = 30,
6666
wind_interval_minutes: int = 15,
67-
image_embedding_dim: Optional[int] = 318,
67+
num_embeddings: Optional[int] = 318,
6868
timestep_intervals_to_plot: Optional[list[int]] = None,
6969
):
7070
"""Neural network which combines information from different sources.
@@ -114,7 +114,7 @@ def __init__(
114114
pv_interval_minutes: The interval between each sample of the PV data
115115
sat_interval_minutes: The interval between each sample of the satellite data
116116
sensor_interval_minutes: The interval between each sample of the sensor data
117-
image_embedding_dim: The number of dimensions to use for the image embedding
117+
num_embeddings: The number of dimensions to use for the image embedding
118118
timestep_intervals_to_plot: Intervals, in timesteps, to plot in
119119
addition to the full forecast
120120
sensor_encoder: Encoder for sensor data
@@ -162,7 +162,7 @@ def __init__(
162162
)
163163
if add_image_embedding_channel:
164164
self.sat_embed = ImageEmbedding(
165-
image_embedding_dim, self.sat_sequence_len, self.sat_encoder.image_size_pixels
165+
num_embeddings, self.sat_sequence_len, self.sat_encoder.image_size_pixels
166166
)
167167

168168
# Update num features
@@ -197,7 +197,7 @@ def __init__(
197197
)
198198
if add_image_embedding_channel:
199199
self.nwp_embed_dict[nwp_source] = ImageEmbedding(
200-
image_embedding_dim,
200+
num_embeddings,
201201
nwp_sequence_len,
202202
self.nwp_encoders_dict[nwp_source].image_size_pixels,
203203
)
@@ -245,7 +245,7 @@ def __init__(
245245

246246
if self.embedding_dim:
247247
self.embed = nn.Embedding(
248-
num_embeddings=image_embedding_dim, embedding_dim=embedding_dim
248+
num_embeddings=num_embeddings, embedding_dim=embedding_dim
249249
)
250250

251251
# Update num features

0 commit comments

Comments
 (0)