Skip to content

Commit

Permalink
Include adjustable interval minutes
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jan 19, 2024
1 parent 9cfc8a9 commit f28d77d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
8 changes: 4 additions & 4 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def __init__(
optimizer: AbstractOptimizer,
output_quantiles: Optional[list[float]] = None,
target_key: str = "gsp",
interval_minutes: int = 30,
):
"""Abtstract base class for PVNet submodels.
Expand All @@ -266,17 +267,16 @@ def __init__(
self.output_quantiles = output_quantiles

# Number of timestemps for 30 minutely data
# TODO Change, but make configurable, as India data is 15 minutely
self.history_len_30 = history_minutes // 15
self.forecast_len_30 = forecast_minutes // 15
self.history_len_30 = history_minutes // interval_minutes
self.forecast_len_30 = forecast_minutes // interval_minutes
# self.forecast_len_15 = forecast_minutes // 15
# self.history_len_15 = history_minutes // 15

self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)

self._accumulated_metrics = MetricAccumulator()
self._accumulated_batches = BatchAccumulator(
key_to_keep=self._target_key
key_to_keep=self._target_key_name
)
self._accumulated_y_hat = PredAccumulator()

Expand Down
14 changes: 11 additions & 3 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
wind_history_minutes: Optional[int] = None,
optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
target_key: str = "gsp",
interval_minutes: int = 30,
):
"""Neural network which combines information from different sources.
Expand Down Expand Up @@ -99,6 +100,7 @@ def __init__(
`history_minutes` if not provided.
optimizer: Optimizer factory function used for network.
target_key: The key of the target variable in the batch.
interval_minutes: The interval between each sample of the target data
"""

self.include_gsp_yield_history = include_gsp_yield_history
Expand All @@ -111,13 +113,15 @@ def __init__(
self.embedding_dim = embedding_dim
self.add_image_embedding_channel = add_image_embedding_channel
self.target_key_name = target_key
self.interval_minutes = interval_minutes

super().__init__(
history_minutes=history_minutes,
forecast_minutes=forecast_minutes,
optimizer=optimizer,
output_quantiles=output_quantiles,
target_key=BatchKey.gsp if target_key == "gsp" else BatchKey.wind,
target_key=target_key,
interval_minutes=interval_minutes
)

# Number of features expected by the output_network
Expand Down Expand Up @@ -278,8 +282,12 @@ def forward(self, x):
# *********************** Sensor Data ************************************
# add sensor yield history
if self.include_wind:
# sensor_history = x[BatchKey.sensor][:, : self.history_len_30].float()
modes["wind"] = self.wind_encoder(x)
if self.target_key_name != "wind":
modes["wind"] = self.wind_encoder(x)
else:
# Target is wind, so only take the history
wind_history = x[BatchKey.wind][:, : self.history_len_30].float()
modes["wind"] = self.wind_encoder(wind_history)

if self.include_sun:
sun = torch.cat(
Expand Down

0 comments on commit f28d77d

Please sign in to comment.