@@ -64,7 +64,7 @@ def __init__(
64
64
sat_interval_minutes : int = 5 ,
65
65
sensor_interval_minutes : int = 30 ,
66
66
wind_interval_minutes : int = 15 ,
67
- image_embedding_dim : Optional [int ] = 318 ,
67
+ num_embeddings : Optional [int ] = 318 ,
68
68
timestep_intervals_to_plot : Optional [list [int ]] = None ,
69
69
):
70
70
"""Neural network which combines information from different sources.
@@ -114,7 +114,7 @@ def __init__(
114
114
pv_interval_minutes: The interval between each sample of the PV data
115
115
sat_interval_minutes: The interval between each sample of the satellite data
116
116
sensor_interval_minutes: The interval between each sample of the sensor data
117
- image_embedding_dim : The number of dimensions to use for the image embedding
117
+ num_embeddings : The number of dimensions to use for the image embedding
118
118
timestep_intervals_to_plot: Intervals, in timesteps, to plot in
119
119
addition to the full forecast
120
120
sensor_encoder: Encoder for sensor data
@@ -162,7 +162,7 @@ def __init__(
162
162
)
163
163
if add_image_embedding_channel :
164
164
self .sat_embed = ImageEmbedding (
165
- image_embedding_dim , self .sat_sequence_len , self .sat_encoder .image_size_pixels
165
+ num_embeddings , self .sat_sequence_len , self .sat_encoder .image_size_pixels
166
166
)
167
167
168
168
# Update num features
@@ -197,7 +197,7 @@ def __init__(
197
197
)
198
198
if add_image_embedding_channel :
199
199
self .nwp_embed_dict [nwp_source ] = ImageEmbedding (
200
- image_embedding_dim ,
200
+ num_embeddings ,
201
201
nwp_sequence_len ,
202
202
self .nwp_encoders_dict [nwp_source ].image_size_pixels ,
203
203
)
@@ -245,7 +245,7 @@ def __init__(
245
245
246
246
if self .embedding_dim :
247
247
self .embed = nn .Embedding (
248
- num_embeddings = image_embedding_dim , embedding_dim = embedding_dim
248
+ num_embeddings = num_embeddings , embedding_dim = embedding_dim
249
249
)
250
250
251
251
# Update num features
0 commit comments