Skip to content

Commit f5cd889

Browse files
authored
Merge pull request #179 from openclimatefix/jacob/sensor-attention
Add test data and fix encoder for multiple sites and sensors
2 parents c5b95b5 + c719c69 commit f5cd889

File tree

8 files changed

+184
-16
lines changed

8 files changed

+184
-16
lines changed

pvnet/data/wind_datamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class WindDataModule(BaseDataModule):
1313
def _get_datapipe(self, start_time, end_time):
1414
data_pipeline = windnet_netcdf_datapipe(
1515
self.configuration,
16-
keys=["wind", "nwp"],
16+
keys=["wind", "nwp", "sensor"],
1717
)
1818

1919
data_pipeline = (
@@ -26,7 +26,7 @@ def _get_datapipe(self, start_time, end_time):
2626
def _get_premade_batches_datapipe(self, subdir, shuffle=False):
2727
filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
2828
data_pipeline = windnet_netcdf_datapipe(
29-
keys=["wind", "nwp"],
29+
keys=["wind", "nwp", "sensor"],
3030
filenames=filenames,
3131
)
3232
data_pipeline = (

pvnet/models/multimodal/site_encoders/encoders.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
"""
44

5+
import einops
56
import torch
67
from ocf_datapipes.batch import BatchKey
78
from torch import nn
@@ -128,6 +129,8 @@ def __init__(
128129
target_id_dim: int = 318,
129130
target_key_to_use: str = "gsp",
130131
input_key_to_use: str = "pv",
132+
num_channels: int = 1,
133+
num_sites_in_inference: int = 1,
131134
):
132135
"""A simple attention-based model with a single multihead attention layer
133136
@@ -148,6 +151,13 @@ def __init__(
148151
target_id_dim: The number of unique IDs.
149152
target_key_to_use: The key to use for the target in the attention layer.
150153
input_key_to_use: The key to use for the input in the attention layer.
154+
num_channels: Number of channels in the input data. For single site generation,
155+
this will be 1, as there is no channel dimension; for sensors,
156+
this will probably be higher than that
157+
num_sites_in_inference: Number of sites to use in inference.
158+
This is used to determine the number of sites to use in the
159+
attention layer. For a single site, 1 works, while for multiple sites
160+
(such as multiple sensors), this would be higher than that
151161
152162
"""
153163
super().__init__(sequence_length, num_sites, out_features)
@@ -158,15 +168,18 @@ def __init__(
158168
self.use_id_in_value = use_id_in_value
159169
self.target_key_to_use = target_key_to_use
160170
self.input_key_to_use = input_key_to_use
171+
self.num_channels = num_channels
172+
self.num_sites_in_inference = num_sites_in_inference
161173

162174
if use_id_in_value:
163175
self.value_id_embedding = nn.Embedding(num_sites, id_embed_dim)
164176

165177
self._value_encoder = nn.Sequential(
166178
ResFCNet2(
167-
in_features=sequence_length + int(use_id_in_value) * id_embed_dim,
179+
in_features=sequence_length * self.num_channels
180+
+ int(use_id_in_value) * id_embed_dim,
168181
out_features=out_features,
169-
fc_hidden_features=sequence_length,
182+
fc_hidden_features=sequence_length * self.num_channels,
170183
n_res_blocks=n_kv_res_blocks,
171184
res_block_layers=kv_res_block_layers,
172185
dropout_frac=0,
@@ -175,9 +188,9 @@ def __init__(
175188

176189
self._key_encoder = nn.Sequential(
177190
ResFCNet2(
178-
in_features=sequence_length + id_embed_dim,
191+
in_features=id_embed_dim + sequence_length * self.num_channels,
179192
out_features=kdim,
180-
fc_hidden_features=id_embed_dim + sequence_length,
193+
fc_hidden_features=id_embed_dim + sequence_length * self.num_channels,
181194
n_res_blocks=n_kv_res_blocks,
182195
res_block_layers=kv_res_block_layers,
183196
dropout_frac=0,
@@ -192,6 +205,20 @@ def __init__(
192205
batch_first=True,
193206
)
194207

208+
def _encode_inputs(self, x):
209+
# Shape: [batch size, sequence length, PV site] -> [8, 197, 1]
210+
# Shape: [batch size, station_id, sequence length, channels] -> [8, 26, 197, 23]
211+
input_data = x[BatchKey[f"{self.input_key_to_use}"]]
212+
if len(input_data.shape) == 4: # Has multiple channels
213+
input_data = input_data[:, :, : self.sequence_length]
214+
input_data = einops.rearrange(input_data, "b id s c -> b (s c) id")
215+
else:
216+
input_data = input_data[:, : self.sequence_length]
217+
site_seqs = input_data.float()
218+
batch_size = site_seqs.shape[0]
219+
site_seqs = site_seqs.swapaxes(1, 2) # [batch size, Site ID, sequence length]
220+
return site_seqs, batch_size
221+
195222
def _encode_query(self, x):
196223
# Select the first one
197224
if self.target_key_to_use == "gsp":
@@ -206,34 +233,29 @@ def _encode_query(self, x):
206233
return query
207234

208235
def _encode_key(self, x):
209-
# Shape: [batch size, sequence length, PV site]
210-
site_seqs = x[BatchKey[f"{self.input_key_to_use}"]][:, : self.sequence_length].float()
211-
batch_size = site_seqs.shape[0]
236+
site_seqs, batch_size = self._encode_inputs(x)
212237

213238
# wind ID embeddings are the same for each sample
214239
site_id_embed = torch.tile(self.site_id_embedding(self._ids), (batch_size, 1, 1))
215240
# Each concatenated (wind sequence, wind ID embedding) is processed with encoder
216-
x_seq_in = torch.cat((site_seqs.swapaxes(1, 2), site_id_embed), dim=2).flatten(0, 1)
241+
x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
217242
key = self._key_encoder(x_seq_in)
218243

219244
# Reshape to [batch size, PV site, kdim]
220245
key = key.unflatten(0, (batch_size, self.num_sites))
221246
return key
222247

223248
def _encode_value(self, x):
224-
# Shape: [batch size, sequence length, PV site]
225-
site_seqs = x[BatchKey[f"{self.input_key_to_use}"]][:, : self.sequence_length].float()
226-
batch_size = site_seqs.shape[0]
249+
site_seqs, batch_size = self._encode_inputs(x)
227250

228251
if self.use_id_in_value:
229252
# wind ID embeddings are the same for each sample
230253
site_id_embed = torch.tile(self.value_id_embedding(self._ids), (batch_size, 1, 1))
231254
# Each concatenated (wind sequence, wind ID embedding) is processed with encoder
232-
x_seq_in = torch.cat((site_seqs.swapaxes(1, 2), site_id_embed), dim=2).flatten(0, 1)
255+
x_seq_in = torch.cat((site_seqs, site_id_embed), dim=2).flatten(0, 1)
233256
else:
234257
# Encode each PV sequence independently
235-
x_seq_in = site_seqs.swapaxes(1, 2).flatten(0, 1)
236-
258+
x_seq_in = site_seqs.flatten(0, 1)
237259
value = self._value_encoder(x_seq_in)
238260

239261
# Reshape to [batch size, PV site, vdim]

tests/conftest.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pvnet
1616
from pvnet.data.datamodule import DataModule
17+
from pvnet.data.wind_datamodule import WindDataModule
1718

1819
import pvnet.models.multimodal.encoders.encoders3d
1920
import pvnet.models.multimodal.linear_networks.networks
@@ -161,6 +162,22 @@ def sample_pv_batch(sample_batch):
161162
return pv_data
162163

163164

165+
@pytest.fixture()
166+
def sample_wind_batch():
167+
dm = WindDataModule(
168+
configuration=None,
169+
batch_size=2,
170+
num_workers=0,
171+
prefetch_factor=None,
172+
train_period=[None, None],
173+
val_period=[None, None],
174+
test_period=[None, None],
175+
batch_dir="tests/test_data/sample_wind_batches",
176+
)
177+
batch = next(iter(dm.train_dataloader()))
178+
return batch
179+
180+
164181
@pytest.fixture()
165182
def model_minutes_kwargs():
166183
kwargs = dict(
@@ -193,6 +210,20 @@ def site_encoder_model_kwargs():
193210
return kwargs
194211

195212

213+
@pytest.fixture()
214+
def site_encoder_sensor_model_kwargs():
215+
# Used to test site encoder model on multi-channel sensor data
216+
kwargs = dict(
217+
sequence_length=180 // 5 + 1,
218+
num_sites=26,
219+
out_features=128,
220+
num_channels=23,
221+
target_key_to_use="wind",
222+
input_key_to_use="sensor",
223+
)
224+
return kwargs
225+
226+
196227
@pytest.fixture()
197228
def raw_multimodal_model_kwargs(model_minutes_kwargs):
198229
kwargs = dict(

tests/models/multimodal/site_encoders/test_encoders.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def test_singleattentionnetwork_forward(sample_batch, site_encoder_model_kwargs)
3232
_test_model_forward(sample_batch, SingleAttentionNetwork, site_encoder_model_kwargs)
3333

3434

35+
def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
36+
_test_model_forward(sample_wind_batch, SingleAttentionNetwork, site_encoder_sensor_model_kwargs)
37+
38+
3539
# Test model backward on all models
3640
def test_simplelearnedaggregator_backward(sample_batch, site_encoder_model_kwargs):
3741
_test_model_backward(sample_batch, SimpleLearnedAggregator, site_encoder_model_kwargs)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
general:
2+
description: WindNet configuration for Leonardo
3+
name: windnet_india
4+
5+
input_data:
6+
default_forecast_minutes: 2880
7+
default_history_minutes: 60
8+
data_source_which_defines_geospatial_locations: "wind"
9+
nwp:
10+
ecmwf:
11+
# Path to ECMWF NWP data in zarr format
12+
# n.b. It is not necessary to use multiple or any NWP data. These entries can be removed
13+
nwp_zarr_path: "/mnt/storage_b/data/ocf/solar_pv_nowcasting/nowcasting_dataset_pipeline/NWP/ECMWF/nw-india/zarr/20*.zarr.zip"
14+
history_minutes: 60
15+
forecast_minutes: 2880
16+
time_resolution_minutes: 60
17+
nwp_channels:
18+
#- hcc
19+
#- lcc
20+
#- mcc
21+
#- prate
22+
#- sde
23+
#- sr
24+
- t2m
25+
#- tcc
26+
- u10
27+
- u100
28+
- u200
29+
- v10
30+
- v100
31+
- v200
32+
nwp_image_size_pixels_height: 168 # roughly equivalent to ukv 24 pixels
33+
nwp_image_size_pixels_width: 168
34+
x_dim_name: "longitude"
35+
y_dim_name: "latitude"
36+
nwp_provider: "ecmwf"
37+
dropout_timedeltas_minutes: [-360] # 6 hours
38+
# Dropout applied with this probability
39+
dropout_fraction: 1.0
40+
#start_datetime: "2021-01-01 00:00:00"
41+
#end_datetime: "2024-01-01 00:00:00"
42+
# excarta:
43+
# nwp_zarr_path: "/mnt/storage_b/nwp/excarta/hindcast.zarr"
44+
# history_minutes: 60
45+
# forecast_minutes: 2160 # 48 hours won't work much, as it's only midnight ones, maybe 24 hours to ensure more coverage
46+
# time_resolution_minutes: 60
47+
# nwp_channels:
48+
# - 10u
49+
# - 100u
50+
# - 10v
51+
# - 100v
52+
# - surface_pressure
53+
# #- mean_sea_level_pressure
54+
# nwp_image_size_pixels_height: 64 # roughly equivalent to ukv 24 pixels
55+
# nwp_image_size_pixels_width: 64
56+
# nwp_provider: "excarta"
57+
# x_dim_name: "longitude"
58+
# y_dim_name: "latitude"
59+
# dropout_timedeltas_minutes: [ -360 ] # 6 hours
60+
# # Dropout applied with this probability
61+
# dropout_fraction: 1.0
62+
wind:
63+
wind_files_groups:
64+
- label: india
65+
wind_filename: /mnt/storage_ssd_4tb/india_wind_data.nc
66+
wind_metadata_filename: /mnt/storage_ssd_4tb/india_wind_metadata.csv
67+
get_center: true
68+
n_wind_systems_per_example: 1
69+
#start_datetime: "2021-01-01 00:00:00"
70+
#end_datetime: "2024-01-01 00:00:00"
71+
sensor:
72+
#sensor_files_groups:
73+
# - label: meteomatics
74+
sensor_filename: "/mnt/storage_b/nwp/meteomatics/nw_india/wind*.zarr.zip"
75+
get_center: false
76+
history_minutes: 60
77+
forecast_minutes: 2880
78+
#n_sensor_systems_per_example: 26
79+
time_resolution_minutes: 15
80+
#x_dim_name: "lon"
81+
#y_dim_name: "lat"
82+
sensor_variables:
83+
- 100u
84+
- 100v
85+
- 10u
86+
- 10v
87+
- 200u
88+
- 200v
89+
- cape:Jkg
90+
- air_density_25m:kgm3
91+
- air_density_10m:kgm3
92+
- air_density_100m:kgm3
93+
- air_density_200m:kgm3
94+
- wind_gusts_200m:ms
95+
- wind_gusts_100m:ms
96+
- wind_gusts_10m:ms
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
_target_: pvnet.data.wind_datamodule.WindDataModule
2+
configuration: /home/jacob/PVNet/configs/datamodule/configuration/leonardo_wind_configuration.yaml
3+
num_workers: 1
4+
prefetch_factor: 2
5+
batch_size: 8
6+
batch_dir: /mnt/storage_a/windnet_india_batches_large_meteomatics
7+
train_period:
8+
- "2019-01-01"
9+
- "2022-11-29"
10+
val_period:
11+
- "2022-12-01"
12+
- "2023-12-31"
13+
test_period:
14+
- "2023-09-01"
15+
- "2023-12-31"
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)