Commit 71fbdb0

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4b7d85d commit 71fbdb0

36 files changed: +78 -90 lines changed
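
The removed/added pairs in this commit look identical, which appears consistent with whitespace normalization: stripping trailing whitespace and ending each file with a single newline. The following is a minimal Python sketch of that kind of normalization, not the hooks' actual implementation. The function name `normalize_whitespace` is hypothetical.

```python
def normalize_whitespace(text: str) -> str:
    """Strip trailing whitespace from every line and end with one newline.

    Illustrative sketch only (assumption: "\n" line endings). Note that
    rstrip() also removes a trailing "\r", so CRLF files would be converted.
    """
    # Remove trailing spaces and tabs from each line.
    lines = [line.rstrip() for line in text.split("\n")]
    # Drop blank lines at the end of the file.
    while lines and lines[-1] == "":
        lines.pop()
    if not lines:
        return ""
    # Terminate the file with exactly one newline.
    return "\n".join(lines) + "\n"
```

For example, `normalize_whitespace("main()  \n\n")` returns `"main()\n"`. The changed line keeps the same visible text, which is why such a diff shows a removed and an added line that read the same.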

LICENSE

+1-1
@@ -18,4 +18,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+SOFTWARE.

README.md

-1
@@ -16,4 +16,3 @@ pip install -e .
 ```

 Also download `nowcasting_dataset` and install in the `predict_pv_yield` conda environment using `pip install -e .` from within the `nowcasting_dataset` directory.
-

configs/experiment/conv3d.yaml

+1-1
@@ -24,4 +24,4 @@ datamodule:
 n_val_data: 1000

 model:
-conv3d_channels: 32
+conv3d_channels: 32

configs/experiment/conv3d_sat_nwp.yaml

+1-1
@@ -24,4 +24,4 @@ datamodule:
 n_val_data: 1000

 model:
-conv3d_channels: 32
+conv3d_channels: 32

configs/experiment/example_simple.yaml

+1-1
@@ -24,4 +24,4 @@ datamodule:
 n_val_data: 2
 fake_data: 1

-validate_only: '1' # by putting this key in the config file, the model does not get trained.
+validate_only: '1' # by putting this key in the config file, the model does not get trained.

configs/hparams_search/conv3d_optuna.yaml

-1
@@ -53,4 +53,3 @@ hydra:
 model.fc3_output_features:
 type: categorical
 choices: [8, 16, 32, 64, 128 ]
-

configs/model/conv3d.yaml

+1-1
@@ -12,4 +12,4 @@ conv3d_channels: 32
 fc1_output_features: 128
 fc2_output_features: 128
 fc3_output_features: 64
-output_variable: gsp_yield
+output_variable: gsp_yield

configs/model/conv3d_sat_nwp.yaml

+1-1
@@ -12,4 +12,4 @@ conv3d_channels: 32
 fc1_output_features: 128
 fc2_output_features: 128
 fc3_output_features: 64
-output_variable: gsp_yield
+output_variable: gsp_yield

configs/model/perceiver.yaml

+1-1
@@ -6,4 +6,4 @@ batch_size: 8
 num_latents: 128
 latent_dim: 64
 embedding_dem: 16
-output_variable: gsp_yield
+output_variable: gsp_yield

configs/model/perceiver_conv3d_sat_nwp.yaml

+1-1
@@ -8,4 +8,4 @@ latent_dim: 24
 embedding_dem: 0
 output_variable: gsp_yield
 conv3d_channels: 8
-use_future_satellite_images: 0
+use_future_satellite_images: 0

configs/model/perceiver_sat_nwp.yaml

+1-1
@@ -6,4 +6,4 @@ batch_size: 8
 num_latents: 128
 latent_dim: 64
 embedding_dem: 0
-output_variable: gsp_yield
+output_variable: gsp_yield

configs/readme.md

+1-2
@@ -1,8 +1,7 @@
 The following folders how the configuration files

-This idea is copied from
+This idea is copied from
 https://github.com/ashleve/lightning-hydra-template/blob/main/configs/experiment/example_simple.yaml

 run experiments by:
 `python run.py experiment=example_simple `
-

environment.yml

+4-4
@@ -5,7 +5,7 @@ channels:
 dependencies:
 - python>=3.9
 - pip
-
+
 # Scientific Python
 - numpy
 - pandas
@@ -14,10 +14,10 @@ dependencies:
 - ipykernel
 - pyproj
 - h5netcdf
-
+
 # Cloud & distributed compute
 - gcsfs
-
+
 # Machine learning
 - pytorch::pytorch # explicitly specify pytorch channel to prevent conda from using conda-forge for pytorch, and hence installing the CPU-only version.
 - pytorch-lightning
@@ -27,7 +27,7 @@ dependencies:
 - flake8
 - jedi
 - mypy
-
+
 - pip:
 - neptune-client[pytorch-lightning]
 - tilemapbase # For plotting human-readable geographical maps.

experiments/001_CNN_concat_all_timesteps_as_channels.py

+39-39
@@ -111,21 +111,21 @@ def plot_example(batch, model_output, example_i: int=0, border: int=0):
 fig = plt.figure(figsize=(20, 20))
 ncols=4
 nrows=2
-
+
 # Satellite data
 extent = (
-float(batch['sat_x_coords'][example_i, 0].cpu().numpy()),
-float(batch['sat_x_coords'][example_i, -1].cpu().numpy()),
-float(batch['sat_y_coords'][example_i, -1].cpu().numpy()),
+float(batch['sat_x_coords'][example_i, 0].cpu().numpy()),
+float(batch['sat_x_coords'][example_i, -1].cpu().numpy()),
+float(batch['sat_y_coords'][example_i, -1].cpu().numpy()),
 float(batch['sat_y_coords'][example_i, 0].cpu().numpy())) # left, right, bottom, top
-
+
 def _format_ax(ax):
 #ax.set_xlim(extent[0]-border, extent[1]+border)
 #ax.set_ylim(extent[2]-border, extent[3]+border)
 # ax.coastlines(color='black')
 ax.scatter(
-batch['x_meters_center'][example_i].cpu(),
-batch['y_meters_center'][example_i].cpu(),
+batch['x_meters_center'][example_i].cpu(),
+batch['y_meters_center'][example_i].cpu(),
 s=500, color='white', marker='x')

 ax = fig.add_subplot(nrows, ncols, 1) #, projection=ccrs.OSGB(approx=False))
@@ -140,12 +140,12 @@ def _format_ax(ax):
 ax.imshow(sat_data[params['history_len']+1], extent=extent, interpolation='none', vmin=sat_min, vmax=sat_max)
 ax.set_title('t = 0')
 _format_ax(ax)
-
+
 ax = fig.add_subplot(nrows, ncols, 3)
 ax.imshow(sat_data[-1], extent=extent, interpolation='none', vmin=sat_min, vmax=sat_max)
 ax.set_title('t = {}'.format(params['forecast_len']))
 _format_ax(ax)
-
+
 ax = fig.add_subplot(nrows, ncols, 4)
 lat_lon_bottom_left = osgb_to_lat_lon(extent[0], extent[2])
 lat_lon_top_right = osgb_to_lat_lon(extent[1], extent[3])
@@ -163,7 +163,7 @@ def _format_ax(ax):
 ax = fig.add_subplot(nrows, ncols, 5)
 nwp_dt_index = pd.to_datetime(batch['nwp_target_time'][example_i].cpu().numpy(), unit='s')
 pd.DataFrame(
-batch['nwp'][example_i, :, :, 0, 0].T.cpu().numpy(),
+batch['nwp'][example_i, :, :, 0, 0].T.cpu().numpy(),
 index=nwp_dt_index,
 columns=params['nwp_channels']).plot(ax=ax)
 ax.set_title('NWP')
@@ -194,14 +194,14 @@ def _format_ax(ax):
 ax.legend()

 # fig.tight_layout()
-
+
 return fig


 # In[11]:


-# plot_example(batch, model_output, example_i=20);
+# plot_example(batch, model_output, example_i=20);


 # In[12]:
@@ -234,89 +234,89 @@ def __init__(
 self,
 history_len = params['history_len'],
 forecast_len = params['forecast_len'],
-
+
 ):
 super().__init__()
 self.history_len = history_len
 self.forecast_len = forecast_len
-
+
 self.sat_conv1 = nn.Conv2d(in_channels=history_len+6, out_channels=CHANNELS, kernel_size=KERNEL)#, groups=history_len+1)
 self.sat_conv2 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL) #, groups=CHANNELS//2)
 self.sat_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL) #, groups=CHANNELS)

 self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)
-
+
 self.fc1 = nn.Linear(
-in_features=CHANNELS * 11 * 11,
+in_features=CHANNELS * 11 * 11,
 out_features=256)
-
+
 self.fc2 = nn.Linear(in_features=256 + EMBEDDING_DIM + NWP_SIZE + N_DATETIME_FEATURES + history_len+1, out_features=128)
 #self.fc2 = nn.Linear(in_features=EMBEDDING_DIM + N_DATETIME_FEATURES, out_features=128)
 self.fc3 = nn.Linear(in_features=128, out_features=128)
 self.fc4 = nn.Linear(in_features=128, out_features=128)
 self.fc5 = nn.Linear(in_features=128, out_features=params['forecast_len'])
-
+
 if EMBEDDING_DIM:
 self.pv_system_id_embedding = nn.Embedding(
 num_embeddings=len(data_module.pv_data_source.pv_metadata),
 embedding_dim=EMBEDDING_DIM)
-
+
 def forward(self, x):
 # ******************* Satellite imagery *************************
 # Shape: batch_size, seq_length, width, height, channel
 sat_data = x['sat_data'][:, :self.history_len+1]
 batch_size, seq_len, width, height, n_chans = sat_data.shape
-
+
 # Move seq_length to be the last dim, ready for changing the shape
 sat_data = sat_data.permute(0, 2, 3, 4, 1)
-
+
 # Stack timesteps into the channel dimension
 sat_data = sat_data.view(batch_size, width, height, seq_len * n_chans)
-
+
 sat_data = sat_data.permute(0, 3, 1, 2) # Conv2d expects channels to be the 2nd dim!
-
+
 ### EXTRA CHANNELS
 # Center marker
 center_marker = torch.zeros((batch_size, 1, width, height), dtype=torch.float32, device=self.device)
 half_width = width // 2
 center_marker[..., half_width-2:half_width+2, half_width-2:half_width+2] = 1
-
+
 # geo-spatial x
 x_coords = x['sat_x_coords'] - SAT_X_MEAN
 x_coords /= SAT_X_STD
 x_coords = x_coords.unsqueeze(1).expand(-1, width, -1).unsqueeze(1)
-
+
 # geo-spatial y
 y_coords = x['sat_y_coords'] - SAT_Y_MEAN
 y_coords /= SAT_Y_STD
 y_coords = y_coords.unsqueeze(-1).expand(-1, -1, height).unsqueeze(1)
-
+
 # pixel x & y
 pixel_range = (torch.arange(width, device=self.device) - 64) / 37
 pixel_range = pixel_range.unsqueeze(0).unsqueeze(0)
 pixel_x = pixel_range.unsqueeze(-2).expand(batch_size, 1, width, -1)
 pixel_y = pixel_range.unsqueeze(-1).expand(batch_size, 1, -1, height)
-
+
 # Concat
 sat_data = torch.cat((sat_data, center_marker, x_coords, y_coords, pixel_x, pixel_y), dim=1)
-
+
 del center_marker, x_coords, y_coords, pixel_x, pixel_y
-
+
 # Pass data through the network :)
 out = F.relu(self.sat_conv1(sat_data))
 out = self.maxpool(out)
 out = F.relu(self.sat_conv2(out))
 out = self.maxpool(out)
 out = F.relu(self.sat_conv3(out))
-
+
 out = out.view(-1, CHANNELS * 11 * 11)
 out = F.relu(self.fc1(out))
-
+
 # *********************** NWP Data **************************************
 nwp_data = x['nwp'].float() # Shape: batch_size, channel, seq_length, width, height
 batch_size, n_nwp_chans, nwp_seq_len, nwp_width, nwp_height = nwp_data.shape
 nwp_data = nwp_data.reshape(batch_size, n_nwp_chans * nwp_seq_len * nwp_width * nwp_height)
-
+
 # Concat
 out = torch.cat(
 (
@@ -330,15 +330,15 @@ def forward(self, x):
 ),
 dim=1)
 del nwp_data
-
+
 # Embedding of PV system ID
 if EMBEDDING_DIM:
 pv_embedding = self.pv_system_id_embedding(x['pv_system_row_number'])
 out = torch.cat(
 (
 out,
 pv_embedding
-),
+),
 dim=1)

 # Fully connected layers.
@@ -348,7 +348,7 @@ def forward(self, x):
 out = F.relu(self.fc5(out)) # PV yield is in range [0, 1]. ReLU should train more cleanly than sigmoid.

 return out
-
+
 def _training_or_validation_step(self, batch, is_train_step):
 y_hat = self(batch)
 y = batch['pv_yield'][:, -self.forecast_len:]
@@ -360,19 +360,19 @@ def _training_or_validation_step(self, batch, is_train_step):
 tag = "Train" if is_train_step else "Validation"
 self.log_dict({f'MSE/{tag}': mse_loss}, on_step=is_train_step, on_epoch=True)
 self.log_dict({f'NMAE/{tag}': nmae_loss}, on_step=is_train_step, on_epoch=True)
-
+
 return nmae_loss

 def training_step(self, batch, batch_idx):
 return self._training_or_validation_step(batch, is_train_step=True)
-
+
 def validation_step(self, batch, batch_idx):
 if batch_idx == 0:
 # Plot example
 model_output = self(batch)
 fig = plot_example(batch, model_output)
 self.logger.experiment['validation/plot'].log(File.as_image(fig))
-
+
 return self._training_or_validation_step(batch, is_train_step=False)

 def configure_optimizers(self):
@@ -436,4 +436,4 @@ def configure_optimizers(self):
 trainer.fit(model, data_module)


-# In[ ]:
+# In[ ]:

experiments/2021-08/2021-08-17/run_cnn3d.py

+1-1
@@ -29,4 +29,4 @@ def main():
 trainer.fit(model, train_dataloader)

 if __name__ == '__main__':
-main()
+main()

experiments/2021-08/2021-08-18/run_cnn3d.py

+1-1
@@ -34,4 +34,4 @@ def main():
 # 1. Large training set, and one epoch took a day, so should use GPU for this model. I was a bit suprised as I didnt
 # think the model was so big.
 # 2. Need to work on validationm general validation method. Good to base line against a really simple model. For
-# validation might need to think carefully about metrics that will be used.
+# validation might need to think carefully about metrics that will be used.

experiments/2021-08/2021-08-24/run_cnn3d_n_layers.py

+1-1
@@ -41,4 +41,4 @@ def main():
 # https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-133/monitoring

 # https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-132/monitoring
-# https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-131/monitoring
+# https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-131/monitoring

experiments/2021-08/2021-08-26/run_cnn3d_n_layers.py

-1
@@ -43,4 +43,3 @@ def main():

 # https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-138/charts
 # ran with 10000 in train data
-

experiments/2021-08/2021-08-27/experiments.md

+3-3
@@ -1,8 +1,8 @@
 # Daily Experiments

-Ran hydra for the first time, for hyper parameters optermization.
+Ran hydra for the first time, for hyper parameters optermization.
 It did 2 full runs, then I think ran out of memory caused a funny error.
-Now have install 'psutil' so that cpu and memory is logged to neptune.
+Now have install 'psutil' so that cpu and memory is logged to neptune.

 https://app.neptune.ai/o/OpenClimateFix/org/predict-pv-yield/e/PRED-160/monitoring
 Validation error after 10 epochs - 0.073
@@ -26,4 +26,4 @@ Validation error after 2 epochs - 0.076 (then error happened in 3rd epoch)
 conv3d_channels = 32
 fc1_output_features = 64
 fc2_output_features = 16
-fc3_output_features = 8
+fc3_output_features = 8
