-
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathsingle_value.py
38 lines (30 loc) · 1.21 KB
/
single_value.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Average value model"""
from typing import Optional

import torch
from ocf_datapipes.batch import BatchKey
from torch import nn

import pvnet
from pvnet.models.base_model import BaseModel
from pvnet.optimizers import AbstractOptimizer
class Model(BaseModel):
    """Simple baseline model that predicts always the same value.

    The model holds a single learnable scalar which is broadcast to every
    sample and every output step in ``forward``.
    """

    name = "single_value"

    def __init__(
        self,
        forecast_minutes: int = 120,
        history_minutes: int = 60,
        optimizer: Optional[AbstractOptimizer] = None,
    ):
        """Simple baseline model that predicts always the same value.

        Args:
            forecast_minutes (int): Length of the GSP forecast period in minutes
            history_minutes (int): Length of the GSP history period in minutes
            optimizer (AbstractOptimizer, optional): Optimizer. If omitted, a new
                ``pvnet.optimizers.Adam()`` is created for this model instance.
        """
        if optimizer is None:
            # Build the default here instead of in the signature: a default
            # expression is evaluated once at import time, so the same object
            # would otherwise be shared (and any mutation of it leaked) across
            # every Model created without an explicit optimizer.
            optimizer = pvnet.optimizers.Adam()

        super().__init__(history_minutes, forecast_minutes, optimizer)

        # The single learnable value, initialised to 0.
        self._value = nn.Parameter(torch.zeros(1), requires_grad=True)

        # NOTE(review): assumes BaseModel is a LightningModule, whose
        # save_hyperparameters() reads the __init__ arguments from the current
        # frame; it runs after ``optimizer`` is resolved, so the recorded value
        # is the actual optimizer rather than None — confirm.
        self.save_hyperparameters()

    def forward(self, x: dict):
        """Run model forward on dict batch of data.

        Args:
            x: Batch dict. Only ``x[BatchKey.gsp]`` is read, and only for its
                shape, dtype and device. It must have at least 3 dimensions;
                presumably laid out as (batch, time, ...) — TODO confirm.

        Returns:
            Tensor with one entry per sample (axis 0) and up to
            ``self.forecast_len`` entries along axis 1, every element equal to
            the learned value. ``self.forecast_len`` is assumed to be set by
            BaseModel from ``forecast_minutes`` — verify.
        """
        gsp = x[BatchKey.gsp]
        # zeros_like only supplies the output shape/dtype/device; adding the
        # (1,)-shaped parameter broadcasts it everywhere and keeps the result
        # differentiable w.r.t. ``self._value``.
        y_hat = torch.zeros_like(gsp[:, : self.forecast_len, 0]) + self._value
        return y_hat