@@ -58,6 +58,7 @@ def __init__(
5858 wind_history_minutes : Optional [int ] = None ,
5959 optimizer : AbstractOptimizer = pvnet .optimizers .Adam (),
6060 target_key : str = "gsp" ,
61+ interval_minutes : int = 30 ,
6162 ):
6263 """Neural network which combines information from different sources.
6364
@@ -99,6 +100,7 @@ def __init__(
99100 `history_minutes` if not provided.
100101 optimizer: Optimizer factory function used for network.
101102 target_key: The key of the target variable in the batch.
103+ interval_minutes: The interval between each sample of the target data
102104 """
103105
104106 self .include_gsp_yield_history = include_gsp_yield_history
@@ -111,13 +113,15 @@ def __init__(
111113 self .embedding_dim = embedding_dim
112114 self .add_image_embedding_channel = add_image_embedding_channel
113115 self .target_key_name = target_key
116+ self .interval_minutes = interval_minutes
114117
115118 super ().__init__ (
116119 history_minutes = history_minutes ,
117120 forecast_minutes = forecast_minutes ,
118121 optimizer = optimizer ,
119122 output_quantiles = output_quantiles ,
120- target_key = BatchKey .gsp if target_key == "gsp" else BatchKey .wind ,
123+ target_key = target_key ,
124+ interval_minutes = interval_minutes
121125 )
122126
123127 # Number of features expected by the output_network
@@ -278,8 +282,12 @@ def forward(self, x):
278282 # *********************** Sensor Data ************************************
279283 # add sensor yield history
280284 if self .include_wind :
281- # sensor_history = x[BatchKey.sensor][:, : self.history_len_30].float()
282- modes ["wind" ] = self .wind_encoder (x )
285+ if self .target_key_name != "wind" :
286+ modes ["wind" ] = self .wind_encoder (x )
287+ else :
288+ # Target is wind, so only take the history
289+ wind_history = x [BatchKey .wind ][:, : self .history_len_30 ].float ()
290+ modes ["wind" ] = self .wind_encoder (wind_history )
283291
284292 if self .include_sun :
285293 sun = torch .cat (
0 commit comments