diff --git a/.all-contributorsrc b/.all-contributorsrc
new file mode 100644
index 00000000..6b3d3d2b
--- /dev/null
+++ b/.all-contributorsrc
@@ -0,0 +1,89 @@
+{
+ "files": [
+ "README.md"
+ ],
+ "imageSize": 100,
+ "commit": false,
+ "commitType": "docs",
+ "commitConvention": "angular",
+ "contributors": [
+ {
+ "login": "felix-e-h-p",
+ "name": "Felix",
+ "avatar_url": "https://avatars.githubusercontent.com/u/137530077?v=4",
+ "profile": "https://github.com/felix-e-h-p",
+ "contributions": [
+ "code"
+ ]
+ },
+ {
+ "login": "Sukh-P",
+ "name": "Sukhil Patel",
+ "avatar_url": "https://avatars.githubusercontent.com/u/42407101?v=4",
+ "profile": "https://github.com/Sukh-P",
+ "contributions": [
+ "code"
+ ]
+ },
+ {
+ "login": "dfulu",
+ "name": "James Fulton",
+ "avatar_url": "https://avatars.githubusercontent.com/u/41546094?v=4",
+ "profile": "https://github.com/dfulu",
+ "contributions": [
+ "code"
+ ]
+ },
+ {
+ "login": "AUdaltsova",
+ "name": "Alexandra Udaltsova",
+ "avatar_url": "https://avatars.githubusercontent.com/u/43303448?v=4",
+ "profile": "https://github.com/AUdaltsova",
+ "contributions": [
+ "code"
+ ]
+ },
+ {
+ "login": "zakwatts",
+ "name": "Megawattz",
+ "avatar_url": "https://avatars.githubusercontent.com/u/47150349?v=4",
+ "profile": "https://github.com/zakwatts",
+ "contributions": [
+ "code"
+ ]
+ },
+ {
+ "login": "peterdudfield",
+ "name": "Peter Dudfield",
+ "avatar_url": "https://avatars.githubusercontent.com/u/34686298?v=4",
+ "profile": "https://github.com/peterdudfield",
+ "contributions": [
+ "code"
+ ]
+ },
+ {
+ "login": "mahdilamb",
+ "name": "Mahdi Lamb",
+ "avatar_url": "https://avatars.githubusercontent.com/u/4696915?v=4",
+ "profile": "https://github.com/mahdilamb",
+ "contributions": [
+ "infra"
+ ]
+ },
+ {
+ "login": "jacobbieker",
+ "name": "Jacob Prince-Bieker",
+ "avatar_url": "https://avatars.githubusercontent.com/u/7170359?v=4",
+ "profile": "https://www.jacobbieker.com",
+ "contributions": [
+ "code"
+ ]
+ }
+ ],
+ "contributorsPerLine": 7,
+ "skipCi": true,
+ "repoType": "github",
+ "repoHost": "https://github.com",
+ "projectName": "PVNet",
+ "projectOwner": "openclimatefix"
+}
diff --git a/README.md b/README.md
index 69f54c64..4ca5b95e 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,7 @@
# PVNet 2.1
+
+[![All Contributors](https://img.shields.io/badge/all_contributors-8-orange.svg?style=flat-square)](#contributors-)
+
 [![Python Bump Version & release](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml/badge.svg)](https://github.com/openclimatefix/PVNet/actions/workflows/release.yml) [![ease of contribution: hard](https://img.shields.io/badge/ease%20of%20contribution%3A%20hard-bb2629)](https://github.com/openclimatefix/ocf-meta-repo?tab=readme-ov-file#overview-of-ocfs-nowcasting-repositories)
@@ -231,3 +234,34 @@ If you have successfully trained a PVNet model and have a saved model checkpoint
## Testing
You can use `python -m pytest tests` to run tests
+
+## Contributors ✨
+
+Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
+
+<!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section -->
+<!-- prettier-ignore-start -->
+<!-- markdownlint-disable -->
+<table>
+  <tbody>
+    <tr>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/felix-e-h-p"><img src="https://avatars.githubusercontent.com/u/137530077?v=4?s=100" width="100px;" alt="Felix"/><br /><sub><b>Felix</b></sub></a><br /><a href="https://github.com/openclimatefix/PVNet/commits?author=felix-e-h-p" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/Sukh-P"><img src="https://avatars.githubusercontent.com/u/42407101?v=4?s=100" width="100px;" alt="Sukhil Patel"/><br /><sub><b>Sukhil Patel</b></sub></a><br /><a href="https://github.com/openclimatefix/PVNet/commits?author=Sukh-P" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/dfulu"><img src="https://avatars.githubusercontent.com/u/41546094?v=4?s=100" width="100px;" alt="James Fulton"/><br /><sub><b>James Fulton</b></sub></a><br /><a href="https://github.com/openclimatefix/PVNet/commits?author=dfulu" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/AUdaltsova"><img src="https://avatars.githubusercontent.com/u/43303448?v=4?s=100" width="100px;" alt="Alexandra Udaltsova"/><br /><sub><b>Alexandra Udaltsova</b></sub></a><br /><a href="https://github.com/openclimatefix/PVNet/commits?author=AUdaltsova" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/zakwatts"><img src="https://avatars.githubusercontent.com/u/47150349?v=4?s=100" width="100px;" alt="Megawattz"/><br /><sub><b>Megawattz</b></sub></a><br /><a href="https://github.com/openclimatefix/PVNet/commits?author=zakwatts" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/peterdudfield"><img src="https://avatars.githubusercontent.com/u/34686298?v=4?s=100" width="100px;" alt="Peter Dudfield"/><br /><sub><b>Peter Dudfield</b></sub></a><br /><a href="https://github.com/openclimatefix/PVNet/commits?author=peterdudfield" title="Code">💻</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/mahdilamb"><img src="https://avatars.githubusercontent.com/u/4696915?v=4?s=100" width="100px;" alt="Mahdi Lamb"/><br /><sub><b>Mahdi Lamb</b></sub></a><br /><a href="#infra-mahdilamb" title="Infrastructure (Hosting, Build-Tools, etc)">🚇</a></td>
+    </tr>
+    <tr>
+      <td align="center" valign="top" width="14.28%"><a href="https://www.jacobbieker.com"><img src="https://avatars.githubusercontent.com/u/7170359?v=4?s=100" width="100px;" alt="Jacob Prince-Bieker"/><br /><sub><b>Jacob Prince-Bieker</b></sub></a><br /><a href="https://github.com/openclimatefix/PVNet/commits?author=jacobbieker" title="Code">💻</a></td>
+    </tr>
+  </tbody>
+</table>
+
+<!-- markdownlint-enable -->
+<!-- prettier-ignore-end -->
+<!-- ALL-CONTRIBUTORS-LIST:END -->
+
+This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!
diff --git a/experiments/mae_analysis.py b/experiments/mae_analysis.py
index ac01aed2..66f09024 100644
--- a/experiments/mae_analysis.py
+++ b/experiments/mae_analysis.py
@@ -7,11 +7,26 @@
import argparse
+import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
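+
+# Override matplotlib's default color cycle with the OCF palette so that successive
+# plt.plot calls below pick up these colors in order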
+matplotlib.rcParams["axes.prop_cycle"] = matplotlib.cycler(
+    color=[
+        "#FFD053",  # yellow
+        "#7BCDF3",  # blue
+        "#63BCAF",  # teal
+        "#086788",  # dark blue
+        "#FF9736",  # dark orange
+        "#E4E4E4",  # grey
+        "#14120E",  # black
+        "#FFAC5F",  # orange
+        "#4C9A8E",  # dark teal
+    ]
+)
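+# e.g. with this cycle, the first two lines drawn on an axes are colored
+# "#FFD053" (yellow) then "#7BCDF3" (blue)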
+
def main(project: str, runs: list[str], run_names: list[str]) -> None:
"""
diff --git a/pvnet/data/uk_regional_datamodule.py b/pvnet/data/uk_regional_datamodule.py
index b8774668..9d7323e0 100644
--- a/pvnet/data/uk_regional_datamodule.py
+++ b/pvnet/data/uk_regional_datamodule.py
@@ -1,7 +1,7 @@
""" Data module for pytorch lightning """
from ocf_data_sampler.sample.uk_regional import UKRegionalSample
-from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
from torch.utils.data import Dataset
from pvnet.data.base_datamodule import BaseDataModule, PremadeSamplesDataset
diff --git a/pyproject.toml b/pyproject.toml
index 9f7d09e2..4968db60 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,7 +6,7 @@ dynamic = ["version", "readme"]
license={file="LICENCE"}
dependencies = [
- "ocf_data_sampler==0.0.54",
+ "ocf_data_sampler==0.1.2",
"ocf_datapipes>=3.3.34",
"ocf_ml_metrics>=0.0.11",
"numpy",
diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py
deleted file mode 100644
index 37833b9e..00000000
--- a/scripts/save_concurrent_batches.py
+++ /dev/null
@@ -1,173 +0,0 @@
-"""
-Constructs batches where each batch includes all GSPs and only a single timestamp.
-
-Currently a slightly hacky implementation due to the way the configs are done. This script will use
-the same config file currently set to train the model. In the datamodule config it is possible
-to set the batch_output_dir and number of train/val batches, they can also be overriden in the
-command as shown in the example below.
-
-use:
-```
-python save_concurrent_batches.py \
-    datamodule.batch_output_dir="/mnt/disks/nwp_rechunk/concurrent_batches_v3.9" \
-    datamodule.num_train_batches=20_000 \
-    datamodule.num_val_batches=4_000
-```
-
-"""
-# This is needed to get multiprocessing/multiple workers to behave
-try:
-    import torch.multiprocessing as mp
-
-    mp.set_start_method("spawn", force=True)
-except RuntimeError:
-    pass
-
-import logging
-import os
-import shutil
-import sys
-import warnings
-
-import hydra
-import numpy as np
-import torch
-from ocf_datapipes.batch import BatchKey, batch_to_tensor
-from ocf_datapipes.training.pvnet_all_gsp import (
- construct_sliced_data_pipeline,
- construct_time_pipeline,
-)
-from omegaconf import DictConfig, OmegaConf
-from sqlalchemy import exc as sa_exc
-from torch.utils.data import DataLoader
-from torch.utils.data.datapipes.iter import IterableWrapper
-from tqdm import tqdm
-
-warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
-
-logger = logging.getLogger(__name__)
-
-logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
-
-
-class _save_batch_func_factory:
-    def __init__(self, batch_dir):
-        self.batch_dir = batch_dir
-
-    def __call__(self, input):
-        i, batch = input
-        torch.save(batch, f"{self.batch_dir}/{i:06}.pt")
-
-
-def _get_datapipe(config_path, start_time, end_time, n_batches):
-    t0_datapipe = construct_time_pipeline(
-        config_path,
-        start_time,
-        end_time,
-    )
-
-    t0_datapipe = t0_datapipe.header(n_batches)
-    t0_datapipe = t0_datapipe.sharding_filter()
-
-    datapipe = construct_sliced_data_pipeline(
-        config_path,
-        t0_datapipe,
-    ).map(batch_to_tensor)
-
-    return datapipe
-
-
-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
-    save_func = _save_batch_func_factory(batch_dir)
-    filenumber_pipe = IterableWrapper(np.arange(num_batches)).sharding_filter()
-    save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)
-
-    dataloader = DataLoader(save_pipe, **dataloader_kwargs)
-
-    pbar = tqdm(total=num_batches)
-    for i, batch in zip(range(num_batches), dataloader):
-        pbar.update()
-    pbar.close()
-    del dataloader
-
-
-def check_batch(batch):
- """Check if batch is valid concurrent batch for all GSPs"""
- # Check all GSP IDs are included and in correct order
- assert (batch[BatchKey.gsp_id].flatten().numpy() == np.arange(1, 318)).all()
- # Check all times are the same
- assert len(np.unique(batch[BatchKey.gsp_time_utc][:, 0].numpy())) == 1
- return batch
-
-
-@hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2")
-def main(config: DictConfig):
- """Constructs and saves validation and training batches."""
- config_dm = config.datamodule
-
- # Set up directory
- os.makedirs(config_dm.batch_output_dir, exist_ok=False)
-
- with open(f"{config_dm.batch_output_dir}/datamodule.yaml", "w") as f:
- f.write(OmegaConf.to_yaml(config.datamodule))
-
- shutil.copyfile(
- config_dm.configuration, f"{config_dm.batch_output_dir}/data_configuration.yaml"
- )
-
- dataloader_kwargs = dict(
- shuffle=False,
- batch_size=None, # batched in datapipe step
- sampler=None,
- batch_sampler=None,
- num_workers=config_dm.num_workers,
- collate_fn=None,
- pin_memory=False,
- drop_last=False,
- timeout=0,
- worker_init_fn=None,
- prefetch_factor=config_dm.prefetch_factor,
- persistent_workers=False,
- )
-
- if config_dm.num_val_batches > 0:
- print("----- Saving val batches -----")
-
- os.mkdir(f"{config_dm.batch_output_dir}/val")
-
- val_batch_pipe = _get_datapipe(
- config_dm.configuration,
- *config_dm.val_period,
- config_dm.num_val_batches,
- )
-
- _save_batches_with_dataloader(
- batch_pipe=val_batch_pipe,
- batch_dir=f"{config_dm.batch_output_dir}/val",
- num_batches=config_dm.num_val_batches,
- dataloader_kwargs=dataloader_kwargs,
- )
-
- if config_dm.num_train_batches > 0:
- print("----- Saving train batches -----")
-
- os.mkdir(f"{config_dm.batch_output_dir}/train")
-
- train_batch_pipe = _get_datapipe(
- config_dm.configuration,
- *config_dm.train_period,
- config_dm.num_train_batches,
- )
-
- _save_batches_with_dataloader(
- batch_pipe=train_batch_pipe,
- batch_dir=f"{config_dm.batch_output_dir}/train",
- num_batches=config_dm.num_train_batches,
- dataloader_kwargs=dataloader_kwargs,
- )
-
- print("done")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/scripts/save_concurrent_samples.py b/scripts/save_concurrent_samples.py
new file mode 100644
index 00000000..fc874d97
--- /dev/null
+++ b/scripts/save_concurrent_samples.py
@@ -0,0 +1,191 @@
+"""
+Constructs samples where each sample includes all GSPs for a single timestamp.
+
+Currently a slightly hacky implementation due to the way the configs are done. This script will use
+the same config file currently set to train the model. In the datamodule config it is possible
+to set the sample_output_dir and number of train/val samples; these can also be overridden on the
+command line, as shown in the example below (the leading "+" adds the keys to the config).
+
+use:
+```
+python save_concurrent_samples.py \
+    +datamodule.sample_output_dir="/mnt/disks/concurrent_batches/concurrent_samples_sat_pred_test" \
+    +datamodule.num_train_samples=20 \
+    +datamodule.num_val_samples=20
+```
+
+"""
+# Ensure this block of code runs only in the main process to avoid issues with worker processes.
+if __name__ == "__main__":
+    import torch.multiprocessing as mp
+
+    # Set the start method for torch multiprocessing. Choose either "forkserver" or "spawn" to be
+    # compatible with dask's multiprocessing.
+    mp.set_start_method("forkserver")
+
+    # Set the sharing strategy to 'file_system' to handle file descriptor limitations. This is
+    # important because libraries like Zarr may open many files, which can exhaust the file
+    # descriptor limit if too many workers are used.
+    mp.set_sharing_strategy("file_system")
+
+
+import os
+import sys
+import shutil
+from tqdm import tqdm
+
+import warnings
+import logging
+from sqlalchemy import exc as sa_exc
+
+import hydra
+from omegaconf import DictConfig, OmegaConf
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
+
+from pvnet.utils import print_config
+
+
+# ------- filter warnings and set up logging -------
+
+warnings.filterwarnings("ignore", category=sa_exc.SAWarning)
+
+logger = logging.getLogger(__name__)
+
+logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
+
+# -------------------------------------------------
+
+
+class SaveFuncFactory:
+ """Factory for creating a function to save a sample to disk."""
+
+ def __init__(self, save_dir: str):
+ """Factory for creating a function to save a sample to disk."""
+ self.save_dir = save_dir
+
+ def __call__(self, sample, sample_num: int):
+ """Save a sample to disk"""
+ torch.save(sample, f"{self.save_dir}/{sample_num:08}.pt")
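+
+# Usage sketch (with a hypothetical directory): SaveFuncFactory("samples/val")(sample, 3)
+# would write the sample to "samples/val/00000003.pt" via torch.save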
+
+
+def save_samples_with_dataloader(
+    dataset: Dataset,
+    save_dir: str,
+    num_samples: int,
+    dataloader_kwargs: dict,
+) -> None:
+    """Save samples from a dataset using a dataloader."""
+    save_func = SaveFuncFactory(save_dir)
+
+    dataloader = DataLoader(dataset, **dataloader_kwargs)
+
+    pbar = tqdm(total=num_samples)
+    for i, sample in zip(range(num_samples), dataloader):
+        check_sample(sample)
+        save_func(sample, i)
+        pbar.update()
+    pbar.close()
+
+
+def check_sample(sample):
+ """Check if sample is valid concurrent batch for all GSPs"""
+ # Check all GSP IDs are included and in correct order
+ assert (sample["gsp_id"].flatten().numpy() == np.arange(1, 318)).all()
+ # Check all times are the same
+ assert len(np.unique(sample["gsp_time_utc"][:, 0].numpy())) == 1
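+    # i.e. all 317 regional GSPs (IDs 1-317) must be present and ordered; the
+    # national aggregate (GSP 0) is not included in these concurrent samples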
+
+
+@hydra.main(config_path="../configs/", config_name="config.yaml", version_base="1.2")
+def main(config: DictConfig) -> None:
+ """Constructs and saves validation and training samples."""
+ config_dm = config.datamodule
+
+ print_config(config, resolve=False)
+
+ # Set up directory
+ os.makedirs(config_dm.sample_output_dir, exist_ok=False)
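+    # exist_ok=False makes this raise if the directory already exists, so previously
+    # saved samples are never silently overwritten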
+
+    # Copy across configs which define the samples into the new sample directory
+    with open(f"{config_dm.sample_output_dir}/datamodule.yaml", "w") as f:
+        f.write(OmegaConf.to_yaml(config_dm))
+
+    shutil.copyfile(
+        config_dm.configuration, f"{config_dm.sample_output_dir}/data_configuration.yaml"
+    )
+
+    # Define the kwargs going into the train and val dataloaders
+    dataloader_kwargs = dict(
+        shuffle=True,
+        batch_size=None,
+        sampler=None,
+        batch_sampler=None,
+        num_workers=config_dm.num_workers,
+        collate_fn=None,
+        pin_memory=False,  # Only using CPU to prepare samples so pinning is not beneficial
+        drop_last=False,
+        timeout=0,
+        worker_init_fn=None,
+        prefetch_factor=config_dm.prefetch_factor,
+        persistent_workers=False,  # Not needed since we only enter the dataloader loop once
+    )
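+    # batch_size=None disables the DataLoader's automatic batching: each item yielded
+    # by PVNetUKConcurrentDataset is already a concurrent sample covering all GSPs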
+
+    if config_dm.num_val_samples > 0:
+        print("----- Saving val samples -----")
+
+        val_output_dir = f"{config_dm.sample_output_dir}/val"
+
+        # Make directory for val samples
+        os.mkdir(val_output_dir)
+
+        # Get the dataset
+        val_dataset = PVNetUKConcurrentDataset(
+            config_dm.configuration,
+            start_time=config_dm.val_period[0],
+            end_time=config_dm.val_period[1],
+        )
+
+        # Save samples
+        save_samples_with_dataloader(
+            dataset=val_dataset,
+            save_dir=val_output_dir,
+            num_samples=config_dm.num_val_samples,
+            dataloader_kwargs=dataloader_kwargs,
+        )
+
+        del val_dataset
+
+    if config_dm.num_train_samples > 0:
+        print("----- Saving train samples -----")
+
+        train_output_dir = f"{config_dm.sample_output_dir}/train"
+
+        # Make directory for train samples
+        os.mkdir(train_output_dir)
+
+        # Get the dataset
+        train_dataset = PVNetUKConcurrentDataset(
+            config_dm.configuration,
+            start_time=config_dm.train_period[0],
+            end_time=config_dm.train_period[1],
+        )
+
+        # Save samples
+        save_samples_with_dataloader(
+            dataset=train_dataset,
+            save_dir=train_output_dir,
+            num_samples=config_dm.num_train_samples,
+            dataloader_kwargs=dataloader_kwargs,
+        )
+
+        del train_dataset
+
+    print("----- Saving complete -----")
+
+
+if __name__ == "__main__":
+    main()