Skip to content

Commit f28d77d

Browse files
committed
Include adjustable interval minutes
1 parent 9cfc8a9 commit f28d77d

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

pvnet/models/base_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def __init__(
240240
optimizer: AbstractOptimizer,
241241
output_quantiles: Optional[list[float]] = None,
242242
target_key: str = "gsp",
243+
interval_minutes: int = 30,
243244
):
244245
"""Abtstract base class for PVNet submodels.
245246
@@ -266,17 +267,16 @@ def __init__(
266267
self.output_quantiles = output_quantiles
267268

268269
# Number of timestemps for 30 minutely data
269-
# TODO Change, but make configurable, as India data is 15 minutely
270-
self.history_len_30 = history_minutes // 15
271-
self.forecast_len_30 = forecast_minutes // 15
270+
self.history_len_30 = history_minutes // interval_minutes
271+
self.forecast_len_30 = forecast_minutes // interval_minutes
272272
# self.forecast_len_15 = forecast_minutes // 15
273273
# self.history_len_15 = history_minutes // 15
274274

275275
self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)
276276

277277
self._accumulated_metrics = MetricAccumulator()
278278
self._accumulated_batches = BatchAccumulator(
279-
key_to_keep=self._target_key
279+
key_to_keep=self._target_key_name
280280
)
281281
self._accumulated_y_hat = PredAccumulator()
282282

pvnet/models/multimodal/multimodal.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
wind_history_minutes: Optional[int] = None,
5959
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
6060
target_key: str = "gsp",
61+
interval_minutes: int = 30,
6162
):
6263
"""Neural network which combines information from different sources.
6364
@@ -99,6 +100,7 @@ def __init__(
99100
`history_minutes` if not provided.
100101
optimizer: Optimizer factory function used for network.
101102
target_key: The key of the target variable in the batch.
103+
interval_minutes: The interval between each sample of the target data
102104
"""
103105

104106
self.include_gsp_yield_history = include_gsp_yield_history
@@ -111,13 +113,15 @@ def __init__(
111113
self.embedding_dim = embedding_dim
112114
self.add_image_embedding_channel = add_image_embedding_channel
113115
self.target_key_name = target_key
116+
self.interval_minutes = interval_minutes
114117

115118
super().__init__(
116119
history_minutes=history_minutes,
117120
forecast_minutes=forecast_minutes,
118121
optimizer=optimizer,
119122
output_quantiles=output_quantiles,
120-
target_key=BatchKey.gsp if target_key == "gsp" else BatchKey.wind,
123+
target_key=target_key,
124+
interval_minutes=interval_minutes
121125
)
122126

123127
# Number of features expected by the output_network
@@ -278,8 +282,12 @@ def forward(self, x):
278282
# *********************** Sensor Data ************************************
279283
# add sensor yield history
280284
if self.include_wind:
281-
# sensor_history = x[BatchKey.sensor][:, : self.history_len_30].float()
282-
modes["wind"] = self.wind_encoder(x)
285+
if self.target_key_name != "wind":
286+
modes["wind"] = self.wind_encoder(x)
287+
else:
288+
# Target is wind, so only take the history
289+
wind_history = x[BatchKey.wind][:, : self.history_len_30].float()
290+
modes["wind"] = self.wind_encoder(wind_history)
283291

284292
if self.include_sun:
285293
sun = torch.cat(

0 commit comments

Comments
 (0)