
Commit 12dc2bf

Merge branch 'main' of https://github.com/openclimatefix/PVNet into main

2 parents: 16c9b5e + acb6a36

File tree

6 files changed: +149 −143 lines

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 [bumpversion]
 commit = True
 tag = True
-current_version = 3.0.58
+current_version = 3.0.63
 message = Bump version: {current_version} → {new_version} [skip ci]
 
 [bumpversion:file:pvnet/__init__.py]

experiments/analysis.py renamed to experiments/mae_analysis.py

Lines changed: 28 additions & 14 deletions

@@ -1,5 +1,8 @@
 """
-Script to generate a table comparing two run for MAE values for 48 hour 15 minute forecast
+Script to generate analysis of MAE values for multiple model forecasts
+
+Does this for 48 hour horizon forecasts with 15 minute granularity
+
 """
 
 import argparse
@@ -10,15 +13,21 @@
 import wandb
 
 
-def main(runs: list[str], run_names: list[str]) -> None:
+def main(project: str, runs: list[str], run_names: list[str]) -> None:
     """
-    Compare two runs for MAE values for 48 hour 15 minute forecast
+    Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity
+
+    Args:
+        project: name of W&B project
+        runs: W&B ids of runs
+        run_names: user specified names for runs
+
     """
     api = wandb.Api()
     dfs = []
     epoch_num = []
     for run in runs:
-        run = api.run(f"openclimatefix/PROJECT/{run}")
+        run = api.run(f"openclimatefix/{project}/{run}")
 
         df = run.history(samples=run.lastHistoryStep + 1)
         # Get the columns that are in the format 'MAE_horizon/step_<number>/val`
@@ -88,36 +97,41 @@ def main(runs: list[str], run_names: list[str]) -> None:
     for idx, df in enumerate(dfs):
         print(f"{run_names[idx]}: {df.mean()*100:0.3f}")
 
-    # Plot the error on per timestep, and all timesteps
+    # Plot the error per timestep
     plt.figure()
     for idx, df in enumerate(dfs):
-        plt.plot(column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}")
+        plt.plot(
+            column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-"
+        )
     plt.legend()
    plt.xlabel("Timestep (minutes)")
     plt.ylabel("MAE %")
     plt.title("MAE % for each timestep")
     plt.savefig("mae_per_timestep.png")
     plt.show()
 
-    # Plot the error on per timestep, and grouped timesteps
+    # Plot the error per grouped timestep
     plt.figure()
     for idx, run_name in enumerate(run_names):
-        plt.plot(groups_df[run_name], label=f"{run_name}, epoch: {epoch_num[idx]}")
+        plt.plot(
+            groups_df[run_name],
+            label=f"{run_name}, epoch: {epoch_num[idx]}",
+            marker="o",
+            linestyle="-",
+        )
     plt.legend()
     plt.xlabel("Timestep (minutes)")
     plt.ylabel("MAE %")
-    plt.title("MAE % for each timestep")
-    plt.savefig("mae_per_timestep.png")
+    plt.title("MAE % for each grouped timestep")
+    plt.savefig("mae_per_grouped_timestep.png")
     plt.show()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    "5llq8iw6"
-    parser.add_argument("--first_run", type=str, default="xdlew7ib")
-    parser.add_argument("--second_run", type=str, default="v3mja33d")
+    parser.add_argument("--project", type=str, default="")
     # Add arguments that is a list of strings
     parser.add_argument("--list_of_runs", nargs="+")
     parser.add_argument("--run_names", nargs="+")
     args = parser.parse_args()
-    main(args.list_of_runs, args.run_names)
+    main(args.project, args.list_of_runs, args.run_names)
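
As a usage illustration (not part of the commit), the renamed script can now be pointed at any W&B project; the project name, run ids, and run names below are hypothetical placeholders.

# Minimal sketch of calling the renamed script's entry point directly; values are placeholders.
from experiments.mae_analysis import main

main(
    project="PVNet",                      # hypothetical W&B project name
    runs=["run_id_1", "run_id_2"],        # hypothetical W&B run ids
    run_names=["baseline", "candidate"],  # labels used in the plot legends
)

# Equivalent CLI invocation (also with placeholder values):
#   python experiments/mae_analysis.py --project PVNet \
#       --list_of_runs run_id_1 run_id_2 --run_names baseline candidate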

pvnet/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 """PVNet"""
-__version__ = "3.0.58"
+__version__ = "3.0.63"

pvnet/models/base_model.py

Lines changed: 3 additions & 1 deletion

@@ -666,7 +666,8 @@ def validation_step(self, batch: dict, batch_idx):
         # Sensor seems to be in batch, station, time order
         y = batch[self._target_key][:, -self.forecast_len :, 0]
 
-        self._log_validation_results(batch, y_hat, accum_batch_num)
+        if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
+            self._log_validation_results(batch, y_hat, accum_batch_num)
 
         # Expand persistence to be the same shape as y
         losses = self._calculate_common_losses(y, y_hat)
@@ -743,6 +744,7 @@ def on_validation_epoch_end(self):
             print("Failed to log validation results to wandb")
             print(e)
 
+        self.validation_epoch_results = []
        horizon_maes_dict = self._horizon_maes.flush()
 
         # Create the horizon accuracy curve
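
For context (not from the commit itself), the new condition in validation_step means validation results are only logged once per group of accumulated batches. A minimal sketch, assuming trainer.accumulate_grad_batches is 4:

# Sketch: with accumulate_grad_batches = 4, only every fourth validation batch
# now calls _log_validation_results.
accumulate_grad_batches = 4
logged_batches = [
    batch_idx
    for batch_idx in range(12)
    if (batch_idx + 1) % accumulate_grad_batches == 0
]
print(logged_batches)  # [3, 7, 11]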

scripts/backtest_sites.py

Lines changed: 116 additions & 22 deletions

@@ -23,6 +23,7 @@
 except RuntimeError:
     pass
 
+import json
 import logging
 import os
 import sys
@@ -32,6 +33,8 @@
 import pandas as pd
 import torch
 import xarray as xr
+from huggingface_hub import hf_hub_download
+from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
 from ocf_datapipes.batch import (
     BatchKey,
     NumpyBatch,
@@ -50,26 +53,26 @@
 )
 from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
 from omegaconf import DictConfig
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterDataPipe, functional_datapipe
 from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
 
 from pvnet.load_model import get_model_from_checkpoints
 from pvnet.utils import SiteLocationLookup
 
 # ------------------------------------------------------------------
-# USER CONFIGURED VARIABLES
+# USER CONFIGURED VARIABLES TO RUN THE SCRIPT
+
+# Directory path to save results
 output_dir = "PLACEHOLDER"
 
 # Local directory to load the PVNet checkpoint from. By default this should pull the best performing
 # checkpoint on the val set
 model_chckpoint_dir = "PLACEHOLDER"
 
-# Local directory to load the summation model checkpoint from. By default this should pull the best
-# performing checkpoint on the val set. If set to None a simple sum is used instead
-# summation_chckpoint_dir = (
-#     "/home/jamesfulton/repos/PVNet_summation/checkpoints/pvnet_summation/nw673nw2"
-# )
+hf_revision = None
+hf_token = None
+hf_model_id = None
 
 # Forecasts will be made for all available init times between these
 start_datetime = "2022-05-08 00:00"
@@ -96,18 +99,96 @@
 # When sun as elevation below this, the forecast is set to zero
 MIN_DAY_ELEVATION = 0
 
-# All pv system ids to produce forecasts for
+# Add all pv site ids here that you wish to produce forecasts for
 ALL_SITE_IDS = []
+# Need to be in ascending order
+ALL_SITE_IDS.sort()
 
 # ------------------------------------------------------------------
 # FUNCTIONS
 
 
+@functional_datapipe("pad_forward_pv")
+class PadForwardPVIterDataPipe(IterDataPipe):
+    """
+    Pads forecast pv.
+
+    Sun position is calculated based off of pv time index
+    and for t0's close to end of pv data can have wrong shape as pv starts
+    to run out of data to slice for the forecast part.
+    """
+
+    def __init__(
+        self,
+        pv_dp: IterDataPipe,
+        forecast_duration: np.timedelta64,
+        history_duration: np.timedelta64,
+        time_resolution_minutes: np.timedelta64,
+    ):
+        """Init"""
+
+        super().__init__()
+        self.pv_dp = pv_dp
+        self.forecast_duration = forecast_duration
+        self.history_duration = history_duration
+        self.time_resolution_minutes = time_resolution_minutes
+
+        self.min_seq_length = history_duration // time_resolution_minutes
+
+    def __iter__(self):
+        """Iter"""
+
+        for xr_data in self.pv_dp:
+            t_end = (
+                xr_data.time_utc.data[0]
+                + self.history_duration
+                + self.forecast_duration
+                + self.time_resolution_minutes
+            )
+            time_idx = np.arange(xr_data.time_utc.data[0], t_end, self.time_resolution_minutes)
+
+            if len(xr_data.time_utc.data) < self.min_seq_length:
+                raise ValueError("Not enough PV data to predict")
+
+            yield xr_data.reindex(time_utc=time_idx, fill_value=-1)
+
+
+def load_model_from_hf(model_id: str, revision: str, token: str):
+    """
+    Loads model from HuggingFace
+    """
+
+    model_file = hf_hub_download(
+        repo_id=model_id,
+        filename=PYTORCH_WEIGHTS_NAME,
+        revision=revision,
+        token=token,
+    )
+
+    # load config file
+    config_file = hf_hub_download(
+        repo_id=model_id,
+        filename=CONFIG_NAME,
+        revision=revision,
+        token=token,
+    )
+
+    with open(config_file, "r", encoding="utf-8") as f:
+        config = json.load(f)
+
+    model = hydra.utils.instantiate(config)
+
+    state_dict = torch.load(model_file, map_location=torch.device("cuda"))
+    model.load_state_dict(state_dict)  # type: ignore
+    model.eval()  # type: ignore
+
+    return model
+
+
 def preds_to_dataarray(preds, model, valid_times, site_ids):
     """Put numpy array of predictions into a dataarray"""
 
     if model.use_quantile_regression:
-        output_labels = model.output_quantiles
         output_labels = [f"forecast_mw_plevel_{int(q*100):02}" for q in model.output_quantiles]
         output_labels[output_labels.index("forecast_mw_plevel_50")] = "forecast_mw"
     else:
@@ -255,7 +336,8 @@ def get_loctimes_datapipes(config_path):
         unbatch_level=1
     )  # might not need this part since the site datapipe is creating examples
 
-    # Create times datapipe so each worker receives 317 copies of the same datetime for its batch
+    # Create times datapipe so each worker receives
+    # len(ALL_SITE_IDS) copies of the same datetime for its batch
     t0_datapipe = IterableWrapper(
         [[t0 for site_id in ALL_SITE_IDS] for t0 in available_target_times]
     )
@@ -305,7 +387,7 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
         )
 
         # Get effective capacities for this forecast
-        # site_capacities = ds_site.nominal_capacity_wp.values
+        site_capacities = self.ds_site.nominal_capacity_wp.values
         # Get the solar elevations. We need to un-normalise these from the values in the batch
         elevation = batch[BatchKey.pv_solar_elevation] * ELEVATION_STD + ELEVATION_MEAN
         # We only need elevation mask for forecasted values, not history
@@ -327,18 +409,17 @@ def predict_batch(self, batch: NumpyBatch) -> xr.Dataset:
         y_normed_site = model(device_batch).detach().cpu().numpy()
         da_normed_site = preds_to_dataarray(y_normed_site, model, valid_times, ALL_SITE_IDS)
 
-        # TODO fix this step: Multiply normalised forecasts by capacities and clip negatives
-        # For now output normalised by capacity outputs and unnormalise in post processing
-        # da_abs_site = da_normed_site.clip(0, None) * site_capacities[:, None, None]
-        da_normed_site = da_normed_site.clip(0, None)
+        # Multiply normalised forecasts by capacities and clip negatives
+        da_abs_site = da_normed_site.clip(0, None) * site_capacities[:, None, None]
+
         # Apply sundown mask
-        da_normed_site = da_normed_site.where(~da_sundown_mask).fillna(0.0)
+        da_abs_site = da_abs_site.where(~da_sundown_mask).fillna(0.0)
 
-        da_normed_site = da_normed_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
-            init_time_utc=[t0]
+        da_abs_site = da_abs_site.expand_dims(dim="init_time_utc", axis=0).assign_coords(
+            init_time_utc=np.array([t0], dtype="datetime64[ns]")
         )
 
-        return da_normed_site
+        return da_abs_site
 
 
 def get_datapipe(config_path: str) -> NumpyBatch:
@@ -364,6 +445,13 @@ def get_datapipe(config_path: str) -> NumpyBatch:
         t0_datapipe,
     )
 
+    config = load_yaml_configuration(config_path)
+    data_pipeline["pv"] = data_pipeline["pv"].pad_forward_pv(
+        forecast_duration=np.timedelta64(config.input_data.pv.forecast_minutes, "m"),
+        history_duration=np.timedelta64(config.input_data.pv.history_minutes, "m"),
+        time_resolution_minutes=np.timedelta64(config.input_data.pv.time_resolution_minutes, "m"),
+    )
+
     data_pipeline = DictDatasetIterDataPipe(
         {k: v for k, v in data_pipeline.items() if k != "config"},
     ).map(split_dataset_dict_dp)
@@ -414,7 +502,13 @@ def main(config: DictConfig):
     # Create a dataloader for the concurrent batches and use multiprocessing
     dataloader = DataLoader(batch_pipe, **dataloader_kwargs)
     # Load the PVNet model
-    model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
+    if model_chckpoint_dir:
+        model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
+    elif hf_model_id:
+        model = load_model_from_hf(hf_model_id, hf_revision, hf_token)
+    else:
+        raise ValueError("Provide a model checkpoint or a HuggingFace model")
+
     model = model.eval().to(device)
 
     # Create object to make predictions for each input batch
@@ -428,13 +522,13 @@ def main(config: DictConfig):
 
             t0 = ds_abs_all.init_time_utc.values[0]
 
-            # Save the predictioons
+            # Save the predictions
             filename = f"{output_dir}/{t0}.nc"
             ds_abs_all.to_netcdf(filename)
 
             pbar.update()
         except Exception as e:
-            print(f"Exception {e} at {i}")
+            print(f"Exception {e} at batch {i}")
             pass
 
     # Close down
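
As a hedged illustration of the new model-loading options in this script (the path, repo id, and revision below are placeholders, not values from the commit):

# Sketch of the two ways the edited script can now locate a model; all values are hypothetical.

# Option 1: load from a local checkpoint directory (existing behaviour)
model_chckpoint_dir = "/path/to/checkpoints/pvnet"  # hypothetical path
hf_model_id = None

# Option 2: load weights and config from HuggingFace via load_model_from_hf
model_chckpoint_dir = None
hf_model_id = "openclimatefix/example_pvnet_model"  # hypothetical repo id
hf_revision = "main"                                # hypothetical revision
hf_token = None                                     # only needed for private repos

# main() then picks the source:
#   if model_chckpoint_dir: model, *_ = get_model_from_checkpoints([model_chckpoint_dir], val_best=True)
#   elif hf_model_id:       model = load_model_from_hf(hf_model_id, hf_revision, hf_token)
#   else:                   raise ValueError("Provide a model checkpoint or a HuggingFace model")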
