Skip to content

Commit 4f0a3dd

Browse files
authored
Merge branch 'dev-data-sampler' into data_sampler
2 parents 292ad9d + 12dc2bf commit 4f0a3dd

File tree

10 files changed

+237
-174
lines changed

10 files changed

+237
-174
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[bumpversion]
22
commit = True
33
tag = True
4-
current_version = 3.0.53
4+
current_version = 3.0.63
55
message = Bump version: {current_version} → {new_version} [skip ci]
66

77
[bumpversion:file:pvnet/__init__.py]

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# PVNet 2.1
22

3-
[![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml)
3+
[![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml) [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
4+
45

56
This project is used for training PVNet and running PVNet on live data.
67

@@ -85,6 +86,8 @@ OCF maintains a Zarr formatted version of the German Weather Service's (DWD)
8586
ICON-EU NWP model here:
8687
https://huggingface.co/datasets/openclimatefix/dwd-icon-eu which includes the UK
8788

89+
Please note that the current version of [ICON loader]([url](https://github.com/openclimatefix/ocf_datapipes/blob/9ec252eeee44937c12ab52699579bdcace76e72f/ocf_datapipes/load/nwp/providers/icon.py#L9-L30)) supports a different format. If you want to use our ICON-EU dataset or your own NWP source, you can create a loader for it using [the instructions here]([url](https://github.com/openclimatefix/ocf_datapipes/tree/main/ocf_datapipes/load#nwp)).
90+
8891
**PV**\
8992
OCF maintains a dataset of PV generation from 1311 private PV installations
9093
here: https://huggingface.co/datasets/openclimatefix/uk_pv

experiments/analysis.py renamed to experiments/mae_analysis.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""
2-
Script to generate a table comparing two run for MAE values for 48 hour 15 minute forecast
2+
Script to generate analysis of MAE values for multiple model forecasts
3+
4+
Does this for 48 hour horizon forecasts with 15 minute granularity
5+
36
"""
47

58
import argparse
@@ -10,16 +13,23 @@
1013
import wandb
1114

1215

13-
def main(runs: list[str], run_names: list[str]) -> None:
16+
def main(project: str, runs: list[str], run_names: list[str]) -> None:
1417
"""
15-
Compare two runs for MAE values for 48 hour 15 minute forecast
18+
Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity
19+
20+
Args:
21+
project: name of W&B project
22+
runs: W&B ids of runs
23+
run_names: user specified names for runs
24+
1625
"""
1726
api = wandb.Api()
1827
dfs = []
28+
epoch_num = []
1929
for run in runs:
20-
run = api.run(f"openclimatefix/india/{run}")
30+
run = api.run(f"openclimatefix/{project}/{run}")
2131

22-
df = run.history()
32+
df = run.history(samples=run.lastHistoryStep + 1)
2333
# Get the columns that are in the format 'MAE_horizon/step_<number>/val`
2434
mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col]
2535
# Sort them
@@ -40,6 +50,7 @@ def main(runs: list[str], run_names: list[str]) -> None:
4050
# Get the step from the column name
4151
column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols]
4252
dfs.append(df)
53+
epoch_num.append(min_row_idx)
4354
# Get the timedelta for each group
4455
groupings = [
4556
[0, 0],
@@ -86,36 +97,41 @@ def main(runs: list[str], run_names: list[str]) -> None:
8697
for idx, df in enumerate(dfs):
8798
print(f"{run_names[idx]}: {df.mean()*100:0.3f}")
8899

89-
# Plot the error on per timestep, and all timesteps
100+
# Plot the error per timestep
90101
plt.figure()
91102
for idx, df in enumerate(dfs):
92-
plt.plot(column_timesteps, df, label=run_names[idx])
103+
plt.plot(
104+
column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-"
105+
)
93106
plt.legend()
94107
plt.xlabel("Timestep (minutes)")
95108
plt.ylabel("MAE %")
96109
plt.title("MAE % for each timestep")
97110
plt.savefig("mae_per_timestep.png")
98111
plt.show()
99112

100-
# Plot the error on per timestep, and grouped timesteps
113+
# Plot the error per grouped timestep
101114
plt.figure()
102-
for run_name in run_names:
103-
plt.plot(groups_df[run_name], label=run_name)
115+
for idx, run_name in enumerate(run_names):
116+
plt.plot(
117+
groups_df[run_name],
118+
label=f"{run_name}, epoch: {epoch_num[idx]}",
119+
marker="o",
120+
linestyle="-",
121+
)
104122
plt.legend()
105123
plt.xlabel("Timestep (minutes)")
106124
plt.ylabel("MAE %")
107-
plt.title("MAE % for each timestep")
108-
plt.savefig("mae_per_timestep.png")
125+
plt.title("MAE % for each grouped timestep")
126+
plt.savefig("mae_per_grouped_timestep.png")
109127
plt.show()
110128

111129

112130
if __name__ == "__main__":
113131
parser = argparse.ArgumentParser()
114-
"5llq8iw6"
115-
parser.add_argument("--first_run", type=str, default="xdlew7ib")
116-
parser.add_argument("--second_run", type=str, default="v3mja33d")
132+
parser.add_argument("--project", type=str, default="")
117133
# Add arguments that is a list of strings
118134
parser.add_argument("--list_of_runs", nargs="+")
119135
parser.add_argument("--run_names", nargs="+")
120136
args = parser.parse_args()
121-
main(args.list_of_runs, args.run_names)
137+
main(args.project, args.list_of_runs, args.run_names)

pvnet/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""PVNet"""
2-
__version__ = "3.0.53"
2+
__version__ = "3.0.63"

pvnet/models/base_model.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import logging
44
import os
5+
import tempfile
56
from pathlib import Path
67
from typing import Dict, Optional, Union
78

@@ -13,7 +14,7 @@
1314
import torch.nn.functional as F
1415
import wandb
1516
import yaml
16-
from huggingface_hub import ModelCard, ModelCardData
17+
from huggingface_hub import ModelCard, ModelCardData, PyTorchModelHubMixin
1718
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
1819
from huggingface_hub.file_download import hf_hub_download
1920
from huggingface_hub.hf_api import HfApi
@@ -144,7 +145,7 @@ def minimize_data_config(input_path, output_path, model):
144145
yaml.dump(config, outfile, default_flow_style=False)
145146

146147

147-
class PVNetModelHubMixin:
148+
class PVNetModelHubMixin(PyTorchModelHubMixin):
148149
"""
149150
Implementation of [`PyTorchModelHubMixin`] to provide model Hub upload/download capabilities.
150151
"""
@@ -415,7 +416,10 @@ def __init__(
415416
self.num_output_features = self.forecast_len * len(self.output_quantiles)
416417
else:
417418
self.num_output_features = self.forecast_len
418-
419+
420+
# save all validation results to array, so we can save these to weights n biases
421+
self.validation_epoch_results = []
422+
419423
def transfer_batch_to_device(self, batch, device, dataloader_idx):
420424
"""Method to move custom batches to a given device"""
421425
return copy_batch_to_device(batch, device)
@@ -605,12 +609,62 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
605609
print(e)
606610
plt.close(fig)
607611

612+
def _log_validation_results(self, batch, y_hat, accum_batch_num):
613+
"""Append validation results to self.validation_epoch_results"""
614+
615+
# get truth values, shape (b, forecast_len)
616+
y = batch[self._target_key][:, -self.forecast_len :, 0]
617+
y = y.detach().cpu().numpy()
618+
batch_size = y.shape[0]
619+
620+
# get prediction values, shape (b, forecast_len, quantiles?)
621+
y_hat = y_hat.detach().cpu().numpy()
622+
623+
# get time_utc, shape (b, forecast_len)
624+
time_utc_key = BatchKey[f"{self._target_key_name}_time_utc"]
625+
time_utc = batch[time_utc_key][:, -self.forecast_len :].detach().cpu().numpy()
626+
627+
# get target id and change from (b,1) to (b,)
628+
id_key = BatchKey[f"{self._target_key_name}_id"]
629+
target_id = batch[id_key].detach().cpu().numpy()
630+
target_id = target_id.squeeze()
631+
632+
for i in range(batch_size):
633+
y_i = y[i]
634+
y_hat_i = y_hat[i]
635+
time_utc_i = time_utc[i]
636+
target_id_i = target_id[i]
637+
638+
results_dict = {
639+
"y": y_i,
640+
"time_utc": time_utc_i,
641+
}
642+
if self.use_quantile_regression:
643+
results_dict.update(
644+
{f"y_quantile_{q}": y_hat_i[:, i] for i, q in enumerate(self.output_quantiles)}
645+
)
646+
else:
647+
results_dict["y_hat"] = y_hat_i
648+
649+
results_df = pd.DataFrame(results_dict)
650+
results_df["id"] = target_id_i
651+
results_df["batch_idx"] = accum_batch_num
652+
results_df["example_idx"] = i
653+
654+
self.validation_epoch_results.append(results_df)
655+
608656
def validation_step(self, batch: dict, batch_idx):
609657
"""Run validation step"""
658+
659+
accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
660+
610661
y_hat = self(batch)
611662

612663
y = batch[self._target_key][:, -self.forecast_len :]
613664

665+
if (batch_idx + 1) % self.trainer.accumulate_grad_batches == 0:
666+
self._log_validation_results(batch, y_hat, accum_batch_num)
667+
614668
# Expand persistence to be the same shape as y
615669
losses = self._calculate_common_losses(y, y_hat)
616670
losses.update(self._calculate_val_losses(y, y_hat))
@@ -628,8 +682,6 @@ def validation_step(self, batch: dict, batch_idx):
628682
on_epoch=True,
629683
)
630684

631-
accum_batch_num = batch_idx // self.trainer.accumulate_grad_batches
632-
633685
# Make plots only if using wandb logger
634686
if isinstance(self.logger, pl.loggers.WandbLogger) and accum_batch_num in [0, 1]:
635687
# Store these temporarily under self
@@ -671,6 +723,24 @@ def validation_step(self, batch: dict, batch_idx):
671723
def on_validation_epoch_end(self):
672724
"""Run on epoch end"""
673725

726+
try:
727+
# join together validation results, and save to wandb
728+
validation_results_df = pd.concat(self.validation_epoch_results)
729+
with tempfile.TemporaryDirectory() as tempdir:
730+
filename = os.path.join(tempdir, f"validation_results_{self.current_epoch}.csv")
731+
validation_results_df.to_csv(filename, index=False)
732+
733+
# make and log wand artifact
734+
validation_artifact = wandb.Artifact(
735+
f"validation_results_epoch_{self.current_epoch}", type="dataset"
736+
)
737+
validation_artifact.add_file(filename)
738+
wandb.log_artifact(validation_artifact)
739+
except Exception as e:
740+
print("Failed to log validation results to wandb")
741+
print(e)
742+
743+
self.validation_epoch_results = []
674744
horizon_maes_dict = self._horizon_maes.flush()
675745

676746
# Create the horizon accuracy curve

0 commit comments

Comments
 (0)