@@ -111,21 +111,21 @@ def plot_example(batch, model_output, example_i: int=0, border: int=0):
    fig = plt.figure(figsize=(20, 20))
    ncols = 4
    nrows = 2
-
+
    # Satellite data
    extent = (
-        float(batch['sat_x_coords'][example_i, 0].cpu().numpy()),
-        float(batch['sat_x_coords'][example_i, -1].cpu().numpy()),
-        float(batch['sat_y_coords'][example_i, -1].cpu().numpy()),
+        float(batch['sat_x_coords'][example_i, 0].cpu().numpy()),
+        float(batch['sat_x_coords'][example_i, -1].cpu().numpy()),
+        float(batch['sat_y_coords'][example_i, -1].cpu().numpy()),
        float(batch['sat_y_coords'][example_i, 0].cpu().numpy())) # left, right, bottom, top
-
+
    def _format_ax(ax):
        #ax.set_xlim(extent[0]-border, extent[1]+border)
        #ax.set_ylim(extent[2]-border, extent[3]+border)
        # ax.coastlines(color='black')
        ax.scatter(
-            batch['x_meters_center'][example_i].cpu(),
-            batch['y_meters_center'][example_i].cpu(),
+            batch['x_meters_center'][example_i].cpu(),
+            batch['y_meters_center'][example_i].cpu(),
            s=500, color='white', marker='x')

    ax = fig.add_subplot(nrows, ncols, 1) #, projection=ccrs.OSGB(approx=False))
@@ -140,12 +140,12 @@ def _format_ax(ax):
    ax.imshow(sat_data[params['history_len'] + 1], extent=extent, interpolation='none', vmin=sat_min, vmax=sat_max)
    ax.set_title('t = 0')
    _format_ax(ax)
-
+
    ax = fig.add_subplot(nrows, ncols, 3)
    ax.imshow(sat_data[-1], extent=extent, interpolation='none', vmin=sat_min, vmax=sat_max)
    ax.set_title('t = {}'.format(params['forecast_len']))
    _format_ax(ax)
-
+
    ax = fig.add_subplot(nrows, ncols, 4)
    lat_lon_bottom_left = osgb_to_lat_lon(extent[0], extent[2])
    lat_lon_top_right = osgb_to_lat_lon(extent[1], extent[3])
@@ -163,7 +163,7 @@ def _format_ax(ax):
    ax = fig.add_subplot(nrows, ncols, 5)
    nwp_dt_index = pd.to_datetime(batch['nwp_target_time'][example_i].cpu().numpy(), unit='s')
    pd.DataFrame(
-        batch['nwp'][example_i, :, :, 0, 0].T.cpu().numpy(),
+        batch['nwp'][example_i, :, :, 0, 0].T.cpu().numpy(),
        index=nwp_dt_index,
        columns=params['nwp_channels']).plot(ax=ax)
    ax.set_title('NWP')
@@ -194,14 +194,14 @@ def _format_ax(ax):
    ax.legend()

    # fig.tight_layout()
-
+
    return fig


# In[11]:


-# plot_example(batch, model_output, example_i=20);
+# plot_example(batch, model_output, example_i=20);


# In[12]:
@@ -234,89 +234,89 @@ def __init__(
        self,
        history_len=params['history_len'],
        forecast_len=params['forecast_len'],
-
+
    ):
        super().__init__()
        self.history_len = history_len
        self.forecast_len = forecast_len
-
+
        self.sat_conv1 = nn.Conv2d(in_channels=history_len + 6, out_channels=CHANNELS, kernel_size=KERNEL) #, groups=history_len+1)
        self.sat_conv2 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL) #, groups=CHANNELS//2)
        self.sat_conv3 = nn.Conv2d(in_channels=CHANNELS, out_channels=CHANNELS, kernel_size=KERNEL) #, groups=CHANNELS)

        self.maxpool = nn.MaxPool2d(kernel_size=KERNEL)
-
+
        self.fc1 = nn.Linear(
-            in_features=CHANNELS * 11 * 11,
+            in_features=CHANNELS * 11 * 11,
            out_features=256)
-
+
        self.fc2 = nn.Linear(in_features=256 + EMBEDDING_DIM + NWP_SIZE + N_DATETIME_FEATURES + history_len + 1, out_features=128)
        #self.fc2 = nn.Linear(in_features=EMBEDDING_DIM + N_DATETIME_FEATURES, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=128)
        self.fc4 = nn.Linear(in_features=128, out_features=128)
        self.fc5 = nn.Linear(in_features=128, out_features=params['forecast_len'])
-
+
        if EMBEDDING_DIM:
            self.pv_system_id_embedding = nn.Embedding(
                num_embeddings=len(data_module.pv_data_source.pv_metadata),
                embedding_dim=EMBEDDING_DIM)
-
+
    def forward(self, x):
        # ******************* Satellite imagery *************************
        # Shape: batch_size, seq_length, width, height, channel
        sat_data = x['sat_data'][:, :self.history_len + 1]
        batch_size, seq_len, width, height, n_chans = sat_data.shape
-
+
        # Move seq_length to be the last dim, ready for changing the shape
        sat_data = sat_data.permute(0, 2, 3, 4, 1)
-
+
        # Stack timesteps into the channel dimension
        sat_data = sat_data.view(batch_size, width, height, seq_len * n_chans)
-
+
        sat_data = sat_data.permute(0, 3, 1, 2) # Conv2d expects channels to be the 2nd dim!
-
+
        ### EXTRA CHANNELS
        # Center marker
        center_marker = torch.zeros((batch_size, 1, width, height), dtype=torch.float32, device=self.device)
        half_width = width // 2
        center_marker[..., half_width - 2:half_width + 2, half_width - 2:half_width + 2] = 1
-
+
        # geo-spatial x
        x_coords = x['sat_x_coords'] - SAT_X_MEAN
        x_coords /= SAT_X_STD
        x_coords = x_coords.unsqueeze(1).expand(-1, width, -1).unsqueeze(1)
-
+
        # geo-spatial y
        y_coords = x['sat_y_coords'] - SAT_Y_MEAN
        y_coords /= SAT_Y_STD
        y_coords = y_coords.unsqueeze(-1).expand(-1, -1, height).unsqueeze(1)
-
+
        # pixel x & y
        pixel_range = (torch.arange(width, device=self.device) - 64) / 37
        pixel_range = pixel_range.unsqueeze(0).unsqueeze(0)
        pixel_x = pixel_range.unsqueeze(-2).expand(batch_size, 1, width, -1)
        pixel_y = pixel_range.unsqueeze(-1).expand(batch_size, 1, -1, height)
-
+
        # Concat
        sat_data = torch.cat((sat_data, center_marker, x_coords, y_coords, pixel_x, pixel_y), dim=1)
-
+
        del center_marker, x_coords, y_coords, pixel_x, pixel_y
-
+
        # Pass data through the network :)
        out = F.relu(self.sat_conv1(sat_data))
        out = self.maxpool(out)
        out = F.relu(self.sat_conv2(out))
        out = self.maxpool(out)
        out = F.relu(self.sat_conv3(out))
-
+
        out = out.view(-1, CHANNELS * 11 * 11)
        out = F.relu(self.fc1(out))
-
+
        # *********************** NWP Data **************************************
        nwp_data = x['nwp'].float() # Shape: batch_size, channel, seq_length, width, height
        batch_size, n_nwp_chans, nwp_seq_len, nwp_width, nwp_height = nwp_data.shape
        nwp_data = nwp_data.reshape(batch_size, n_nwp_chans * nwp_seq_len * nwp_width * nwp_height)
-
+
        # Concat
        out = torch.cat(
            (
@@ -330,15 +330,15 @@ def forward(self, x):
            ),
            dim=1)
        del nwp_data
-
+
        # Embedding of PV system ID
        if EMBEDDING_DIM:
            pv_embedding = self.pv_system_id_embedding(x['pv_system_row_number'])
            out = torch.cat(
                (
                    out,
                    pv_embedding
-                ),
+                ),
                dim=1)

        # Fully connected layers.
@@ -348,7 +348,7 @@ def forward(self, x):
        out = F.relu(self.fc5(out)) # PV yield is in range [0, 1]. ReLU should train more cleanly than sigmoid.

        return out
-
+
    def _training_or_validation_step(self, batch, is_train_step):
        y_hat = self(batch)
        y = batch['pv_yield'][:, -self.forecast_len:]
@@ -360,19 +360,19 @@ def _training_or_validation_step(self, batch, is_train_step):
        tag = "Train" if is_train_step else "Validation"
        self.log_dict({f'MSE/{tag}': mse_loss}, on_step=is_train_step, on_epoch=True)
        self.log_dict({f'NMAE/{tag}': nmae_loss}, on_step=is_train_step, on_epoch=True)
-
+
        return nmae_loss

    def training_step(self, batch, batch_idx):
        return self._training_or_validation_step(batch, is_train_step=True)
-
+
    def validation_step(self, batch, batch_idx):
        if batch_idx == 0:
            # Plot example
            model_output = self(batch)
            fig = plot_example(batch, model_output)
            self.logger.experiment['validation/plot'].log(File.as_image(fig))
-
+
        return self._training_or_validation_step(batch, is_train_step=False)

    def configure_optimizers(self):
@@ -436,4 +436,4 @@ def configure_optimizers(self):
trainer.fit(model, data_module)


-# In[ ]:
+# In[ ]:
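For reference, a minimal standalone sketch (separate from the commit above) of the tensor preparation that forward() performs before the Conv2d stack: timesteps are folded into the channel dimension and extra coordinate channels are appended. The batch size, image size and pixel-coordinate normalisation below are illustrative assumptions, and reshape is used instead of view so the permuted tensor does not need to be contiguous.

import torch

batch_size, seq_len, width, height, n_chans = 4, 7, 128, 128, 1
sat_data = torch.rand(batch_size, seq_len, width, height, n_chans)

# Move seq_len to the end, then fold (channel, time) into a single channel axis.
sat_data = sat_data.permute(0, 2, 3, 4, 1)
sat_data = sat_data.reshape(batch_size, width, height, seq_len * n_chans)
sat_data = sat_data.permute(0, 3, 1, 2)  # Conv2d expects (batch, channels, H, W)

# Extra channels: a marker at the image centre plus normalised pixel coordinates
# (illustrative normalisation; the commit uses its own constants and also adds
# geo-spatial x/y channels).
center_marker = torch.zeros(batch_size, 1, width, height)
half = width // 2
center_marker[..., half - 2:half + 2, half - 2:half + 2] = 1

pixel_range = (torch.arange(width) - width // 2) / (width / 2)  # roughly in [-1, 1]
pixel_x = pixel_range.view(1, 1, 1, width).expand(batch_size, 1, height, width)
pixel_y = pixel_range.view(1, 1, width, 1).expand(batch_size, 1, width, height)

sat_data = torch.cat((sat_data, center_marker, pixel_x, pixel_y), dim=1)
print(sat_data.shape)  # torch.Size([4, 10, 128, 128])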