Skip to content

Commit 505d7d0

Browse files
Merge pull request #88 from openclimatefix/issue/85-perciever-nwp
Issue/85 perciever nwp
2 parents 294b093 + a8ec8be commit 505d7d0

20 files changed

+798
-13
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# @package _global_
2+
3+
# to execute this experiment run:
4+
# python run.py experiment=example_simple.yaml
5+
6+
defaults:
7+
- override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8+
- override /model: perceiver_conv3d_sat_nwp.yaml
9+
- override /datamodule: netcdf_datamodule_gcp.yaml
10+
- override /callbacks: default.yaml
11+
- override /logger: neptune.yaml
12+
13+
# all parameters below will be merged with parameters from default configurations set above
14+
# this allows you to overwrite only specified parameters
15+
16+
seed: 518
17+
18+
trainer:
19+
min_epochs: 1
20+
max_epochs: 50
21+
22+
datamodule:
23+
n_train_data: 10000
24+
n_val_data: 1000
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# @package _global_
2+
3+
# to execute this experiment run:
4+
# python run.py experiment=example_simple.yaml
5+
6+
defaults:
7+
- override /trainer: default.yaml # choose trainer from 'configs/trainer/'
8+
- override /model: perceiver_sat_nwp.yaml
9+
- override /datamodule: netcdf_datamodule_gcp.yaml
10+
- override /callbacks: default.yaml
11+
- override /logger: neptune.yaml
12+
13+
# all parameters below will be merged with parameters from default configurations set above
14+
# this allows you to overwrite only specified parameters
15+
16+
seed: 518
17+
18+
trainer:
19+
min_epochs: 1
20+
max_epochs: 10
21+
22+
datamodule:
23+
n_train_data: 10000
24+
n_val_data: 1000

configs/model/perceiver.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ history_minutes: 60
55
batch_size: 8
66
num_latents: 128
77
latent_dim: 64
8-
embedding_dem: 16
8+
embedding_dem: 16
9+
output_variable: gsp_yield
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
_target_: predict_pv_yield.models.perceiver.perceiver_conv3d_nwp_sat.Model
2+
3+
forecast_minutes: 30
4+
history_minutes: 60
5+
batch_size: 32
6+
num_latents: 24
7+
latent_dim: 24
8+
embedding_dem: 0
9+
output_variable: gsp_yield
10+
conv3d_channels: 8
11+
use_future_satellite_images: 0

configs/model/perceiver_sat_nwp.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
_target_: predict_pv_yield.models.perceiver.perceiver_nwp_sat.Model
2+
3+
forecast_minutes: 30
4+
history_minutes: 60
5+
batch_size: 8
6+
num_latents: 128
7+
latent_dim: 64
8+
embedding_dem: 0
9+
output_variable: gsp_yield
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Ran perciever RNN model
2+
3+
https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-245/charts
4+
5+
Includes validation images, so we can see how the model is perforaming after in epoch
6+
7+
due memory of gpu had to go
8+
9+
forecast_len: 12
10+
history_len: 6
11+
batch_size: 8
12+
num_latents: 32
13+
latent_dim: 32
14+
embedding_dem: 10

experiments/2021-09/2021-09-27/experiments.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ Using both sat and nwp into two separate convolution nets.
2222

2323
https://app.neptune.ai/OpenClimateFix/predict-pv-yield/e/PRED-320
2424

25-
# TODO Currently running
25+
MAE = 0.0376 - this was after 10 epochs, and I think it was still going down.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
1. Perceiver NWP SAT
2+
3+
https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-331/monitoring
4+
5+
Ran with
6+
- batch_size of 6, as GPU was out of memory
7+
- num_latents: int = 64,
8+
- latent_dim: int = 64,
9+
- embedding_dem: int = 0,
10+
11+
Each epoch takes about 3 hours
12+
13+
Decided to stop it earlier
14+
15+
1. Perceiver Conv3d NWP SAT
16+
17+
Idea is to have 1 conv3d + max pool later before the perceiver model
18+
https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-331/monitoring
19+
20+
Conv3d did not make much memory difference, the biggest being, changing the
21+
- num_latents
22+
- latent_dim
23+
24+
To get batch 32, set
25+
- num_latents = 16
26+
- latent_dim = 16
27+
- PERCEIVER_OUTPUT_SIZE = 512
28+
OR
29+
To get batch 32, set
30+
- num_latents = 24
31+
- latent_dim = 24
32+
- PERCEIVER_OUTPUT_SIZE = 128
33+
34+
https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-349/monitoring
35+
36+
~ 4 hours per epoch
37+
38+
MAE = 0.0308 (after 10 epochs)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
1. Perceiver Conv3d NWP SAT
2+
3+
No future satellite images
4+
5+
https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-378/charts
6+
7+
~ 4 hours per epoch
8+
9+
MAE = 0.0365 (after 22 epochs), compared to MAE 0.0304 when future satellite images were included

predict_pv_yield/data/dataloader.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
22
from nowcasting_dataset.dataset.datasets import NetCDFDataset, worker_init_fn
33
from nowcasting_dataset.dataset.validate import FakeDataset
4-
from nowcasting_dataset.config.model import Configuration
4+
from nowcasting_dataset.config.load import load_yaml_configuration
55
from typing import Tuple
66
import logging
77
import torch
88
from pytorch_lightning import LightningDataModule
99

1010

11+
1112
_LOG = logging.getLogger(__name__)
1213
_LOG.setLevel(logging.DEBUG)
1314

@@ -20,6 +21,8 @@ def get_dataloaders(
2021
data_path="prepared_ML_training_data/v4/",
2122
) -> Tuple:
2223

24+
configuration = load_yaml_configuration(filename=f'{data_path}/configuration.yaml')
25+
2326
data_module = NetCDFDataModule(
2427
temp_path=temp_path, data_path=data_path, cloud=cloud, n_train_data=n_train_data, n_val_data=n_validation_data
2528
)
@@ -70,6 +73,10 @@ def __init__(
7073
self.pin_memory = pin_memory
7174
self.fake_data = fake_data
7275

76+
filename = os.path.join(data_path, 'configuration.yaml')
77+
_LOG.debug(f'Will be loading the configuration file {filename}')
78+
self.configuration = load_yaml_configuration(filename=filename)
79+
7380
self.dataloader_config = dict(
7481
pin_memory=self.pin_memory,
7582
num_workers=self.num_workers,
@@ -83,40 +90,43 @@ def __init__(
8390

8491
def train_dataloader(self):
8592
if self.fake_data:
86-
train_dataset = FakeDataset(configuration=Configuration())
93+
train_dataset = FakeDataset(configuration=self.configuration)
8794
else:
8895
train_dataset = NetCDFDataset(
8996
self.n_train_data,
9097
os.path.join(self.data_path, "train"),
9198
os.path.join(self.temp_path, "train"),
9299
cloud=self.cloud,
100+
configuration=self.configuration
93101
)
94102

95103
return torch.utils.data.DataLoader(train_dataset, **self.dataloader_config)
96104

97105
def val_dataloader(self):
98106
if self.fake_data:
99-
val_dataset = FakeDataset(configuration=Configuration())
107+
val_dataset = FakeDataset(configuration=self.configuration)
100108
else:
101109
val_dataset = NetCDFDataset(
102110
self.n_val_data,
103111
os.path.join(self.data_path, "validation"),
104112
os.path.join(self.temp_path, "validation"),
105113
cloud=self.cloud,
114+
configuration=self.configuration
106115
)
107116

108117
return torch.utils.data.DataLoader(val_dataset, **self.dataloader_config)
109118

110119
def test_dataloader(self):
111120
if self.fake_data:
112-
test_dataset = FakeDataset(configuration=Configuration())
121+
test_dataset = FakeDataset(configuration=self.configuration)
113122
else:
114123
# TODO need to change this to a test folder
115124
test_dataset = NetCDFDataset(
116125
self.n_val_data,
117126
os.path.join(self.data_path, "validation"),
118127
os.path.join(self.temp_path, "validation"),
119128
cloud=self.cloud,
129+
configuration=self.configuration
120130
)
121131

122132
return torch.utils.data.DataLoader(test_dataset, **self.dataloader_config)

predict_pv_yield/models/base_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ def training_step(self, batch, batch_idx):
114114

115115
def validation_step(self, batch, batch_idx):
116116
INTERESTING_EXAMPLES = (1, 5, 6, 7, 9, 11, 17, 19)
117-
name = f"validation/plot/epoch{self.current_epoch}"
118-
if batch_idx == 0:
117+
name = f"validation/plot/epoch_{self.current_epoch}_{batch_idx}"
118+
if batch_idx in [0, 1, 2, 3, 4]:
119119

120120
# get model outputs
121121
model_output = self(batch)

predict_pv_yield/models/conv3d/model_sat_nwp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def __init__(
8080
* ((image_size_pixels - 2 * self.number_of_conv3d_layers) ** 2)
8181
* (self.forecast_len_5 + self.history_len_5 + 1 - 2 * self.number_of_conv3d_layers)
8282
)
83-
print(self.cnn_output_size)
8483

8584
# conv0
8685
self.sat_conv0 = nn.Conv3d(

predict_pv_yield/models/perceiver/perceiver.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,14 @@ def forward(self, x):
178178
dim=2,
179179
)
180180

181-
# take the history of the pv yield of this system,
182-
pv_yield_history = x["pv_yield"][0 : self.batch_size][:, : self.history_len_5 + 1, 0].unsqueeze(-1)
183-
encoder_input = torch.cat((rnn_input[:, : self.history_len_5 + 1], pv_yield_history), dim=2)
181+
if self.output_variable == 'pv_yield':
182+
# take the history of the pv yield of this system,
183+
pv_yield_history = x["pv_yield"][0: self.batch_size][:, : self.history_len_5 + 1, 0].unsqueeze(-1)
184+
encoder_input = torch.cat((rnn_input[:, : self.history_len_5 + 1], pv_yield_history), dim=2)
185+
elif self.output_variable == 'gsp_yield':
186+
# take the history of the gsp yield of this system,
187+
gsp_history = x[self.output_variable][0: self.batch_size][:, : self.history_len_30 + 1, 0].unsqueeze(-1)
188+
encoder_input = torch.cat((rnn_input[:, : self.history_len_30 + 1], gsp_history), dim=2)
184189

185190
encoder_output, encoder_hidden = self.encoder_rnn(encoder_input)
186191
decoder_output, _ = self.decoder_rnn(rnn_input[:, -self.forecast_len :], encoder_hidden)

0 commit comments

Comments
 (0)