-
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathlast_value.py
43 lines (32 loc) · 1.4 KB
/
last_value.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""Persistence model"""
from ocf_datapipes.batch import BatchKey
import pvnet
from pvnet.models.base_model import BaseModel
from pvnet.optimizers import AbstractOptimizer
class Model(BaseModel):
    """Simple baseline model that takes the last gsp yield value and copies it forward."""

    name = "last_value"

    def __init__(
        self,
        forecast_minutes: int = 12,
        history_minutes: int = 6,
        optimizer: AbstractOptimizer = pvnet.optimizers.Adam(),
    ):
        """Simple baseline model that takes the last gsp yield value and copies it forward.

        Args:
            forecast_minutes (int): Length of the GSP forecast period in minutes
            history_minutes (int): Length of the GSP history period in minutes
            optimizer (AbstractOptimizer): Optimizer. NOTE(review): the default
                instance is created once, when the class is defined, and is shared
                by every Model built without an explicit optimizer — confirm that
                AbstractOptimizer instances hold no per-model state.
        """
        # NOTE(review): arguments are passed positionally as (history, forecast),
        # i.e. in the opposite order to this signature; presumably this matches
        # BaseModel.__init__ — confirm before reordering anything here.
        super().__init__(history_minutes, forecast_minutes, optimizer)
        self.save_hyperparameters()

    def forward(self, x: dict):
        """Run model forward on dict batch of data.

        Predicts by repeating the last non-forecast GSP yield value across every
        forecast step (a persistence forecast).

        Args:
            x: Batch dict. ``x[BatchKey.gsp]`` is a tensor whose dim 1 is time.
                It may be 2D ``(batch_size, seq_length)`` or 3D
                ``(batch_size, seq_length, n_sites)``; in the 3D case only site 0
                (the site being predicted for) is used.

        Returns:
            Tensor of shape ``(batch_size, self.forecast_len)``.
        """
        gsp_yield = x[BatchKey.gsp]

        # Keep only the first site (the one we are predicting for). Without this,
        # a 3D input leaves a 3D tensor after unsqueeze, and ``repeat(1, n)`` below
        # raises because it gets fewer repeat dims than the tensor has.
        if gsp_yield.ndim == 3:
            gsp_yield = gsp_yield[..., 0]

        # The last ``forecast_len`` time steps are treated as the forecast period,
        # so index ``-forecast_len - 1`` is the last non-forecast value.
        y_hat = gsp_yield[:, -self.forecast_len - 1]

        # (batch,) -> (batch, 1) -> (batch, forecast_len): copy it forward.
        out = y_hat.unsqueeze(1).repeat(1, self.forecast_len)
        return out