@@ -63,6 +63,7 @@ def __init__(
63
63
pv_interval_minutes : int = 5 ,
64
64
sat_interval_minutes : int = 5 ,
65
65
sensor_interval_minutes : int = 30 ,
66
+ image_embedding_dim : Optional [int ] = 318 ,
66
67
timestep_intervals_to_plot : Optional [list [int ]] = None ,
67
68
):
68
69
"""Neural network which combines information from different sources.
@@ -111,6 +112,7 @@ def __init__(
111
112
pv_interval_minutes: The interval between each sample of the PV data
112
113
sat_interval_minutes: The interval between each sample of the satellite data
113
114
sensor_interval_minutes: The interval between each sample of the sensor data
115
+ image_embedding_dim: The number of dimensions to use for the image embedding
114
116
timestep_intervals_to_plot: Intervals, in timesteps, to plot in addition to the full forecast
115
117
"""
116
118
@@ -155,7 +157,7 @@ def __init__(
155
157
)
156
158
if add_image_embedding_channel :
157
159
self .sat_embed = ImageEmbedding (
158
- 318 , self .sat_sequence_len , self .sat_encoder .image_size_pixels
160
+ image_embedding_dim , self .sat_sequence_len , self .sat_encoder .image_size_pixels
159
161
)
160
162
161
163
# Update num features
@@ -190,7 +192,7 @@ def __init__(
190
192
)
191
193
if add_image_embedding_channel :
192
194
self .nwp_embed_dict [nwp_source ] = ImageEmbedding (
193
- 318 , nwp_sequence_len , self .nwp_encoders_dict [nwp_source ].image_size_pixels
195
+ image_embedding_dim , nwp_sequence_len , self .nwp_encoders_dict [nwp_source ].image_size_pixels
194
196
)
195
197
196
198
# Update num features
@@ -229,7 +231,7 @@ def __init__(
229
231
fusion_input_features += self .sensor_encoder .out_features
230
232
231
233
if self .embedding_dim :
232
- self .embed = nn .Embedding (num_embeddings = 318 , embedding_dim = embedding_dim )
234
+ self .embed = nn .Embedding (num_embeddings = image_embedding_dim , embedding_dim = embedding_dim )
233
235
234
236
# Update num features
235
237
fusion_input_features += embedding_dim
0 commit comments