Skip to content

Commit 465e518

Browse files
committed
Merge branch 'data_sampler' of https://github.com/openclimatefix/PVNet into data_sampler
2 parents 87d5718 + db81147 commit 465e518

File tree

14 files changed

+294
-254
lines changed

14 files changed

+294
-254
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[bumpversion]
22
commit = True
33
tag = True
4-
current_version = 3.0.53
4+
current_version = 3.0.63
55
message = Bump version: {current_version} → {new_version} [skip ci]
66

77
[bumpversion:file:pvnet/__init__.py]

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# PVNet 2.1
22

3-
[![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml)
3+
[![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml) [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution:%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
4+
45

56
This project is used for training PVNet and running PVNet on live data.
67

@@ -85,6 +86,8 @@ OCF maintains a Zarr formatted version of the German Weather Service's (DWD)
8586
ICON-EU NWP model here:
8687
https://huggingface.co/datasets/openclimatefix/dwd-icon-eu which includes the UK
8788

89+
Please note that the current version of the [ICON loader](https://github.com/openclimatefix/ocf_datapipes/blob/9ec252eeee44937c12ab52699579bdcace76e72f/ocf_datapipes/load/nwp/providers/icon.py#L9-L30) supports a different format. If you want to use our ICON-EU dataset or your own NWP source, you can create a loader for it using [the instructions here](https://github.com/openclimatefix/ocf_datapipes/tree/main/ocf_datapipes/load#nwp).
90+
8891
**PV**\
8992
OCF maintains a dataset of PV generation from 1311 private PV installations
9093
here: https://huggingface.co/datasets/openclimatefix/uk_pv

experiments/analysis.py renamed to experiments/mae_analysis.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""
2-
Script to generate a table comparing two run for MAE values for 48 hour 15 minute forecast
2+
Script to generate analysis of MAE values for multiple model forecasts
3+
4+
Does this for 48 hour horizon forecasts with 15 minute granularity
5+
36
"""
47

58
import argparse
@@ -10,16 +13,23 @@
1013
import wandb
1114

1215

13-
def main(runs: list[str], run_names: list[str]) -> None:
16+
def main(project: str, runs: list[str], run_names: list[str]) -> None:
1417
"""
15-
Compare two runs for MAE values for 48 hour 15 minute forecast
18+
Compare MAE values for multiple model forecasts for 48 hour horizon with 15 minute granularity
19+
20+
Args:
21+
project: name of W&B project
22+
runs: W&B ids of runs
23+
run_names: user specified names for runs
24+
1625
"""
1726
api = wandb.Api()
1827
dfs = []
28+
epoch_num = []
1929
for run in runs:
20-
run = api.run(f"openclimatefix/india/{run}")
30+
run = api.run(f"openclimatefix/{project}/{run}")
2131

22-
df = run.history()
32+
df = run.history(samples=run.lastHistoryStep + 1)
2333
# Get the columns that are in the format `MAE_horizon/step_<number>/val`
2434
mae_cols = [col for col in df.columns if "MAE_horizon/step_" in col and "val" in col]
2535
# Sort them
@@ -40,6 +50,7 @@ def main(runs: list[str], run_names: list[str]) -> None:
4050
# Get the step from the column name
4151
column_timesteps = [int(col.split("_")[-1].split("/")[0]) * 15 for col in mae_cols]
4252
dfs.append(df)
53+
epoch_num.append(min_row_idx)
4354
# Get the timedelta for each group
4455
groupings = [
4556
[0, 0],
@@ -86,36 +97,41 @@ def main(runs: list[str], run_names: list[str]) -> None:
8697
for idx, df in enumerate(dfs):
8798
print(f"{run_names[idx]}: {df.mean()*100:0.3f}")
8899

89-
# Plot the error on per timestep, and all timesteps
100+
# Plot the error per timestep
90101
plt.figure()
91102
for idx, df in enumerate(dfs):
92-
plt.plot(column_timesteps, df, label=run_names[idx])
103+
plt.plot(
104+
column_timesteps, df, label=f"{run_names[idx]}, epoch: {epoch_num[idx]}", linestyle="-"
105+
)
93106
plt.legend()
94107
plt.xlabel("Timestep (minutes)")
95108
plt.ylabel("MAE %")
96109
plt.title("MAE % for each timestep")
97110
plt.savefig("mae_per_timestep.png")
98111
plt.show()
99112

100-
# Plot the error on per timestep, and grouped timesteps
113+
# Plot the error per grouped timestep
101114
plt.figure()
102-
for run_name in run_names:
103-
plt.plot(groups_df[run_name], label=run_name)
115+
for idx, run_name in enumerate(run_names):
116+
plt.plot(
117+
groups_df[run_name],
118+
label=f"{run_name}, epoch: {epoch_num[idx]}",
119+
marker="o",
120+
linestyle="-",
121+
)
104122
plt.legend()
105123
plt.xlabel("Timestep (minutes)")
106124
plt.ylabel("MAE %")
107-
plt.title("MAE % for each timestep")
108-
plt.savefig("mae_per_timestep.png")
125+
plt.title("MAE % for each grouped timestep")
126+
plt.savefig("mae_per_grouped_timestep.png")
109127
plt.show()
110128

111129

112130
if __name__ == "__main__":
113131
parser = argparse.ArgumentParser()
114-
"5llq8iw6"
115-
parser.add_argument("--first_run", type=str, default="xdlew7ib")
116-
parser.add_argument("--second_run", type=str, default="v3mja33d")
132+
parser.add_argument("--project", type=str, default="")
117133
# Add arguments that are lists of strings
118134
parser.add_argument("--list_of_runs", nargs="+")
119135
parser.add_argument("--run_names", nargs="+")
120136
args = parser.parse_args()
121-
main(args.list_of_runs, args.run_names)
137+
main(args.project, args.list_of_runs, args.run_names)

pvnet/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""PVNet"""
2-
__version__ = "3.0.53"
2+
__version__ = "3.0.63"

pvnet/data/datamodule.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,35 @@
11
""" Data module for pytorch lightning """
2-
from datetime import datetime
32
from glob import glob
43

5-
from lightning.pytorch import LightningDataModule
6-
from torch.utils.data import Dataset, DataLoader
74
import torch
8-
9-
from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch, NumpyBatch
10-
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import (
11-
PVNetUKRegionalDataset
12-
)
5+
from lightning.pytorch import LightningDataModule
6+
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset
7+
from ocf_datapipes.batch import NumpyBatch, batch_to_tensor, stack_np_examples_into_batch
8+
from torch.utils.data import DataLoader, Dataset
139

1410

1511
class NumpybatchPremadeSamplesDataset(Dataset):
1612
"""Dataset to load NumpyBatch samples"""
17-
13+
1814
def __init__(self, sample_dir):
1915
"""Dataset to load NumpyBatch samples
20-
16+
2117
Args:
2218
sample_dir: Path to the directory of pre-saved samples.
2319
"""
2420
self.sample_paths = glob(f"{sample_dir}/*.pt")
25-
26-
21+
2722
def __len__(self):
2823
return len(self.sample_paths)
29-
24+
3025
def __getitem__(self, idx):
3126
return torch.load(self.sample_paths[idx])
3227

3328

3429
def collate_fn(samples: list[NumpyBatch]):
3530
"""Convert a list of NumpyBatch samples to a tensor batch"""
3631
return batch_to_tensor(stack_np_examples_into_batch(samples))
37-
32+
3833

3934
class DataModule(LightningDataModule):
4035
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""
@@ -46,9 +41,8 @@ def __init__(
4641
batch_size: int = 16,
4742
num_workers: int = 0,
4843
prefetch_factor: int | None = None,
49-
train_period: list[str|None] = [None, None],
50-
val_period: list[str|None] = [None, None],
51-
44+
train_period: list[str | None] = [None, None],
45+
val_period: list[str | None] = [None, None],
5246
):
5347
"""Datamodule for training pvnet architecture.
5448
@@ -67,7 +61,6 @@ def __init__(
6761
"""
6862
super().__init__()
6963

70-
7164
if not ((sample_dir is not None) ^ (configuration is not None)):
7265
raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")
7366

@@ -100,21 +93,19 @@ def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
10093
def _get_premade_samples_dataset(self, subdir) -> Dataset:
10194
split_dir = f"{self.sample_dir}/{subdir}"
10295
return NumpybatchPremadeSamplesDataset(split_dir)
103-
96+
10497
def train_dataloader(self) -> DataLoader:
10598
"""Construct train dataloader"""
10699
if self.sample_dir is not None:
107100
dataset = self._get_premade_samples_dataset("train")
108101
else:
109102
dataset = self._get_streamed_samples_dataset(*self.train_period)
110103
return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)
111-
104+
112105
def val_dataloader(self) -> DataLoader:
113106
"""Construct val dataloader"""
114107
if self.sample_dir is not None:
115108
dataset = self._get_premade_samples_dataset("val")
116109
else:
117110
dataset = self._get_streamed_samples_dataset(*self.val_period)
118111
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
119-
120-

0 commit comments

Comments
 (0)