except RuntimeError:
    pass

+ import json
import logging
import os
import sys

import pandas as pd
import torch
import xarray as xr
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from ocf_datapipes.batch import (
    BatchKey,
    NumpyBatch,
)
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from omegaconf import DictConfig
- from torch.utils.data import DataLoader
+ from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper
from tqdm import tqdm

from pvnet.load_model import get_model_from_checkpoints
from pvnet.utils import SiteLocationLookup

# ------------------------------------------------------------------
- # USER CONFIGURED VARIABLES
+ # USER CONFIGURED VARIABLES TO RUN THE SCRIPT
+
+ # Directory path to save results
output_dir = "PLACEHOLDER"

# Local directory to load the PVNet checkpoint from. By default this should pull the best performing
# checkpoint on the val set
model_chckpoint_dir = "PLACEHOLDER"

- # Local directory to load the summation model checkpoint from. By default this should pull the best
- # performing checkpoint on the val set. If set to None a simple sum is used instead
- # summation_chckpoint_dir = (
- #     "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/nw673nw2"
- # )
+ hf_revision = None
+ hf_token = None
+ hf_model_id = None
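# To load the model from the HuggingFace hub instead of a local checkpoint, set
# model_chckpoint_dir to None and fill in the three variables above. A purely
# illustrative sketch (the repo id and revision below are hypothetical):
#     hf_model_id = "some-org/pvnet_example"
#     hf_revision = "main"
#     hf_token = None  # only needed for private repos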

# Forecasts will be made for all available init times between these
start_datetime = "2022-05-08 00:00"

# When the sun elevation is below this, the forecast is set to zero
MIN_DAY_ELEVATION = 0

- # All pv system ids to produce forecasts for
+ # Add all pv site ids here that you wish to produce forecasts for
ALL_SITE_IDS = []
+ # Need to be in ascending order
+ ALL_SITE_IDS.sort()

# ------------------------------------------------------------------
# FUNCTIONS


+ @functional_datapipe("pad_forward_pv")
+ class PadForwardPVIterDataPipe(IterDataPipe):
+     """
+     Pads forecast pv.
+
+     Sun position is calculated from the pv time index, so for t0s close to the end of
+     the pv data the forecast part can have the wrong shape as the pv data runs out of
+     values to slice.
+     """
+
+     def __init__(
+         self,
+         pv_dp: IterDataPipe,
+         forecast_duration: np.timedelta64,
+         history_duration: np.timedelta64,
+         time_resolution_minutes: np.timedelta64,
+     ):
+         """Init"""
+
+         super().__init__()
+         self.pv_dp = pv_dp
+         self.forecast_duration = forecast_duration
+         self.history_duration = history_duration
+         self.time_resolution_minutes = time_resolution_minutes
+
+         self.min_seq_length = history_duration // time_resolution_minutes
+
+     def __iter__(self):
+         """Iter"""
+
+         for xr_data in self.pv_dp:
+             t_end = (
+                 xr_data.time_utc.data[0]
+                 + self.history_duration
+                 + self.forecast_duration
+                 + self.time_resolution_minutes
+             )
+             time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes)
+
+             if len(xr_data.time_utc.data) < self.min_seq_length:
+                 raise ValueError("Not enough PV data to predict")
+
+             yield xr_data.reindex(time_utc=time_idx, fill_value=-1)
+
+
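# A minimal sketch of what the padding above does, using a toy 30-minute PV series
# (the values are purely illustrative):
#
#     start = np.datetime64("2022-05-08T00:00")
#     times = start + np.arange(3) * np.timedelta64(30, "m")
#     da = xr.DataArray([0.1, 0.2, 0.3], coords={"time_utc": times}, dims="time_utc")
#     full_idx = np.arange(start, start + np.timedelta64(150, "m"), np.timedelta64(30, "m"))
#     da.reindex(time_utc=full_idx, fill_value=-1)  # 5 steps; the 2 missing ones become -1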
+ def load_model_from_hf(model_id: str, revision: str, token: str):
+     """
+     Loads model from HuggingFace
+     """
+
+     model_file = hf_hub_download(
+         repo_id=model_id,
+         filename=PYTORCH_WEIGHTS_NAME,
+         revision=revision,
+         token=token,
+     )
+
+     # load config file
+     config_file = hf_hub_download(
+         repo_id=model_id,
+         filename=CONFIG_NAME,
+         revision=revision,
+         token=token,
+     )
+
+     with open(config_file, "r", encoding="utf-8") as f:
+         config = json.load(f)
+
+     model = hydra.utils.instantiate(config)
+
+     state_dict = torch.load(model_file, map_location=torch.device("cuda"))
+     model.load_state_dict(state_dict)  # type: ignore
+     model.eval()  # type: ignore
+
+     return model
+
+
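# A hypothetical call, assuming the weights and config live in a HuggingFace repo that the
# token can access (the repo id is illustrative). Note the state dict is mapped to a CUDA
# device above, so a GPU is assumed to be available:
#
#     model = load_model_from_hf("some-org/pvnet_example", revision="main", token=None)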
def preds_to_dataarray(preds, model, valid_times, site_ids):
    """Put numpy array of predictions into a dataarray"""

    if model.use_quantile_regression:
-         output_labels = model.output_quantiles
        output_labels = [f"forecast_mw_plevel_{int(q * 100):02}" for q in model.output_quantiles]
        output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
    else:
@@ -255,7 +336,8 @@ def get_loctimes_datapipes(config_path):
        unbatch_level=1
    )  # might not need this part since the site datapipe is creating examples

-     # Create times datapipe so each worker receives 317 copies of the same datetime for its batch
+     # Create times datapipe so each worker receives
+     # len(ALL_SITE_IDS) copies of the same datetime for its batch
    t0_datapipe = IterableWrapper(
        [[t0 for site_id in ALL_SITE_IDS] for t0 in available_target_times]
    )
@@ -305,7 +387,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
        )

        # Get effective capacities for this forecast
-         # site_capacities = ds_site.nominal_capacity_wp.values
+         site_capacities = self.ds_site.nominal_capacity_wp.values
        # Get the solar elevations. We need to un-normalise these from the values in the batch
        elevation = batch[BatchKey.pv_solar_elevation] * ELEVATION_STD + ELEVATION_MEAN
        # We only need elevation mask for forecasted values, not history
@@ -327,18 +409,17 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
        y_normed_site = model(device_batch).detach().cpu().numpy()
        da_normed_site = preds_to_dataarray(y_normed_site, model, valid_times, ALL_SITE_IDS)

-         # TODO fix this step: Multiply normalised forecasts by capacities and clip negatives
-         # For now output normalised by capacity outputs and unnormalise in post processing
-         # da_abs_site = da_normed_site.clip(0, None) * site_capacities[:, None, None]
-         da_normed_site = da_normed_site.clip(0, None)
+         # Multiply normalised forecasts by capacities and clip negatives
+         da_abs_site = da_normed_site.clip(0, None) * site_capacities[:, None, None]
+
        # Apply sundown mask
-         da_normed_site = da_normed_site.where(~da_sundown_mask).fillna(0.0)
+         da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)

-         da_normed_site = da_normed_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
-             init_time_utc=[t0]
+         da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
+             init_time_utc=np.array([t0], dtype="datetime64[ns]")
        )

-         return da_normed_site
+         return da_abs_site


def get_datapipe(config_path: str) -> NumpyBatch:
@@ -364,6 +445,13 @@ def get_datapipe(config_path: str) -> NumpyBatch:
        t0_datapipe,
    )

+     config = load_yaml_configuration(config_path)
+     data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
+         forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"),
+         history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"),
+         time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"),
+     )
+
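    # The .pad_forward_pv(...) call above is available on the datapipe because the
    # @functional_datapipe("pad_forward_pv") decorator registers PadForwardPVIterDataPipe
    # under that name.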
    data_pipeline = DictDatasetIterDataPipe(
        {k: v for k, v in data_pipeline.items() if k != "config"},
    ).map(split_dataset_dict_dp)
@@ -414,7 +502,13 @@ def main(config: DictConfig):
    # Create a dataloader for the concurrent batches and use multiprocessing
    dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
    # Load the PVNet model
-     model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
+     if model_chckpoint_dir:
+         model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
+     elif hf_model_id:
+         model = load_model_from_hf(hf_model_id, hf_revision, hf_token)
+     else:
+         raise ValueError("Provide a model checkpoint or a HuggingFace model")
+
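    # Note: if both are set, the local checkpoint directory takes precedence over the
    # HuggingFace model id, since the `if` branch above is checked first.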
    model = model.eval().to(device)

    # Create object to make predictions for each input batch
@@ -428,13 +522,13 @@ def main(config: DictConfig):

            t0 = ds_abs_all.init_time_utc.values[0]

-             # Save the predictioons
+             # Save the predictions
            filename = f"{output_dir}/{t0}.nc"
            ds_abs_all.to_netcdf(filename)

            pbar.update()
        except Exception as e:
-             print(f"Exception {e} at {i}")
+             print(f"Exception {e} at batch {i}")
            pass

    # Close down