@@ -58,6 +58,7 @@ def __init__(
58
58
wind_history_minutes : Optional [int ] = None ,
59
59
optimizer : AbstractOptimizer = pvnet .optimizers .Adam (),
60
60
target_key : str = "gsp" ,
61
+ interval_minutes : int = 30 ,
61
62
):
62
63
"""Neural network which combines information from different sources.
63
64
@@ -99,6 +100,7 @@ def __init__(
99
100
`history_minutes` if not provided.
100
101
optimizer: Optimizer factory function used for network.
101
102
target_key: The key of the target variable in the batch.
103
+ interval_minutes: The interval between each sample of the target data
102
104
"""
103
105
104
106
self .include_gsp_yield_history = include_gsp_yield_history
@@ -111,13 +113,15 @@ def __init__(
111
113
self .embedding_dim = embedding_dim
112
114
self .add_image_embedding_channel = add_image_embedding_channel
113
115
self .target_key_name = target_key
116
+ self .interval_minutes = interval_minutes
114
117
115
118
super ().__init__ (
116
119
history_minutes = history_minutes ,
117
120
forecast_minutes = forecast_minutes ,
118
121
optimizer = optimizer ,
119
122
output_quantiles = output_quantiles ,
120
- target_key = BatchKey .gsp if target_key == "gsp" else BatchKey .wind ,
123
+ target_key = target_key ,
124
+ interval_minutes = interval_minutes
121
125
)
122
126
123
127
# Number of features expected by the output_network
@@ -278,8 +282,12 @@ def forward(self, x):
278
282
# *********************** Sensor Data ************************************
279
283
# add sensor yield history
280
284
if self .include_wind :
281
- # sensor_history = x[BatchKey.sensor][:, : self.history_len_30].float()
282
- modes ["wind" ] = self .wind_encoder (x )
285
+ if self .target_key_name != "wind" :
286
+ modes ["wind" ] = self .wind_encoder (x )
287
+ else :
288
+ # Target is wind, so only take the history
289
+ wind_history = x [BatchKey .wind ][:, : self .history_len_30 ].float ()
290
+ modes ["wind" ] = self .wind_encoder (wind_history )
283
291
284
292
if self .include_sun :
285
293
sun = torch .cat (
0 commit comments