Upgrade UK to data sampler #276

Merged · 19 commits · Dec 17, 2024
Changes from 4 commits
169 changes: 125 additions & 44 deletions pvnet/data/datamodule.py
@@ -1,57 +1,138 @@
""" Data module for pytorch lightning """
from datetime import datetime
from glob import glob

import resource

from lightning.pytorch import LightningDataModule
from torch.utils.data import Dataset, DataLoader
import torch
from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch
from ocf_datapipes.training.pvnet import pvnet_datapipe
from torch.utils.data.datapipes.iter import FileLister

from pvnet.data.base import BaseDataModule
from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch, NumpyBatch
from ocf_data_sampler.torch_datasets.pvnet_uk_regional import (
PVNetUKRegionalDataset
)


def fill_nans_in_arrays(batch):
"""Fills all NaN values in each np.ndarray in the batch dictionary with zeros.

Operation is performed in-place on the batch.
"""
for k, v in batch.items():
if isinstance(v, torch.Tensor):
if torch.isnan(v).any():
batch[k] = torch.nan_to_num(v, nan=0.0)

# Recursion is included to reach NWP arrays in subdict
elif isinstance(v, dict):
fill_nans_in_arrays(v)

rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))
return batch


class DataModule(BaseDataModule):

class NumpybatchPremadeSamplesDataset(Dataset):
"""Dataset to load NumpyBatch samples"""

def __init__(self, sample_dir):
"""Dataset to load NumpyBatch samples

Args:
sample_dir: Path to the directory of pre-saved samples.
"""
self.sample_paths = glob(f"{sample_dir}/*.pt")


def __len__(self):
return len(self.sample_paths)

def __getitem__(self, idx):
return fill_nans_in_arrays(torch.load(self.sample_paths[idx]))


def collate_fn(samples: list[NumpyBatch]):
"""Convert a list of NumpyBatch samples to a tensor batch"""
return batch_to_tensor(stack_np_examples_into_batch(samples))
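# Illustrative sketch (assumed usage, placeholder path) of how the pieces above combine:
# the DataLoader calls __getitem__ once per sample and collate_fn once per batch.
#   dataset = NumpybatchPremadeSamplesDataset("/path/to/samples/train")
#   samples = [dataset[i] for i in range(4)]   # each a sample dict with NaNs filled with zeros
#   batch = collate_fn(samples)                # examples stacked and converted to a tensor batch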


class DataModule(LightningDataModule):
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""

def _get_datapipe(self, start_time, end_time):
data_pipeline = pvnet_datapipe(
self.configuration,
start_time=start_time,
end_time=end_time,
)
def __init__(
self,
configuration: str | None = None,
sample_dir: str | None = None,
batch_size: int = 16,
num_workers: int = 0,
prefetch_factor: int | None = None,
train_period: list[str|None] = [None, None],
val_period: list[str|None] = [None, None],

):
"""Datamodule for training pvnet architecture.

Can also be used with pre-made batches if `sample_dir` is set.

data_pipeline = (
data_pipeline.batch(self.batch_size)
.map(stack_np_examples_into_batch)
.map(batch_to_tensor)
Args:
configuration: Path to datapipe configuration file.
sample_dir: Path to the directory of pre-saved samples. Cannot be used together with
`configuration` or `[train/val]_period`.
batch_size: Batch size.
num_workers: Number of workers to use in multiprocess batch loading.
prefetch_factor: Number of batches loaded in advance by each worker.
train_period: Date range filter for train dataloader.
val_period: Date range filter for val dataloader.

"""
super().__init__()


if not ((sample_dir is not None) ^ (configuration is not None)):
raise ValueError("Exactly one of `sample_dir` or `configuration` must be set.")

if sample_dir is not None:
if any([period != [None, None] for period in [train_period, val_period]]):
raise ValueError("Cannot set `(train/val)_period` with presaved samples")

self.configuration = configuration
self.sample_dir = sample_dir
self.train_period = train_period
self.val_period = val_period

self._common_dataloader_kwargs = dict(
batch_size=batch_size,
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=prefetch_factor,
persistent_workers=False,
)
return data_pipeline

def _get_premade_batches_datapipe(self, subdir, shuffle=False):
data_pipeline = FileLister(f"{self.batch_dir}/{subdir}", masks="*.pt", recursive=False)
if shuffle:
data_pipeline = (
data_pipeline.shuffle(buffer_size=10_000)
.sharding_filter()
.map(torch.load)
# Split the batches and reshuffle them to be combined into new batches
.split_batches()
.shuffle(buffer_size=self.shuffle_factor * self.batch_size)
)

def _get_streamed_samples_dataset(self, start_time, end_time) -> Dataset:
return PVNetUKRegionalDataset(self.configuration, start_time=start_time, end_time=end_time)

def _get_premade_samples_dataset(self, subdir) -> Dataset:
split_dir = f"{self.sample_dir}/{subdir}"
return NumpybatchPremadeSamplesDataset(split_dir)

def train_dataloader(self) -> DataLoader:
"""Construct train dataloader"""
if self.sample_dir is not None:
dataset = self._get_premade_samples_dataset("train")
else:
data_pipeline = (
data_pipeline.sharding_filter().map(torch.load)
# Split the batches so we can use any batch-size
.split_batches()
)

data_pipeline = (
data_pipeline.batch(self.batch_size)
.map(stack_np_examples_into_batch)
.map(batch_to_tensor)
)
dataset = self._get_streamed_samples_dataset(*self.train_period)
return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)

def val_dataloader(self) -> DataLoader:
"""Construct val dataloader"""
if self.sample_dir is not None:
dataset = self._get_premade_samples_dataset("val")
else:
dataset = self._get_streamed_samples_dataset(*self.val_period)
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)

return data_pipeline
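As a rough usage sketch of the refactored DataModule above (the paths, dates, and sizes below are illustrative placeholders, not values from this PR):

    from pvnet.data.datamodule import DataModule

    # Streamed samples: ocf-data-sampler builds UK regional samples on the fly from a data config
    streamed = DataModule(
        configuration="path/to/data_config.yaml",  # placeholder
        batch_size=16,
        num_workers=4,
        train_period=["2020-01-01", "2021-12-31"],
        val_period=["2022-01-01", "2022-06-30"],
    )

    # Pre-made samples: loads saved *.pt files from <sample_dir>/train and <sample_dir>/val
    premade = DataModule(
        sample_dir="path/to/presaved/samples",  # placeholder
        batch_size=16,
        num_workers=4,
    )

    train_loader = streamed.train_dataloader()
    val_loader = premade.val_dataloader()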

83 changes: 24 additions & 59 deletions pvnet/models/base_model.py
@@ -17,17 +17,21 @@
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi

from ocf_datapipes.batch import BatchKey
from ocf_datapipes.batch import copy_batch_to_device

from ocf_ml_metrics.evaluation.evaluation import evaluation

from pvnet.models.utils import (
BatchAccumulator,
MetricAccumulator,
PredAccumulator,
WeightedLosses,
)
from pvnet.optimizers import AbstractOptimizer
from pvnet.utils import construct_ocf_ml_metrics_batch_df, plot_batch_forecasts
from pvnet.utils import plot_batch_forecasts



DATA_CONFIG_NAME = "data_config.yaml"

@@ -93,7 +97,7 @@ def minimize_data_config(input_path, output_path, model):
if not model.include_nwp:
del config["input_data"]["nwp"]
else:
for nwp_source in config["input_data"]["nwp"].keys():
for nwp_source in list(config["input_data"]["nwp"].keys()):
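# Iterate over a copy of the keys so entries can safely be removed from the dict inside the loop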
nwp_config = config["input_data"]["nwp"][nwp_source]

if nwp_source not in model.nwp_encoders_dict:
@@ -234,6 +238,13 @@ def get_data_config(
)

return data_config_file


def _save_pretrained(self, save_directory: Path) -> None:
"""Save weights from a Pytorch model to a local directory."""
model_to_save = self.module if hasattr(self, "module") else self # type: ignore
torch.save(model_to_save.state_dict(), save_directory / PYTORCH_WEIGHTS_NAME)


def save_pretrained(
self,
@@ -347,7 +358,6 @@ def __init__(
target_key: str = "gsp",
interval_minutes: int = 30,
timestep_intervals_to_plot: Optional[list[int]] = None,
use_weighted_loss: bool = False,
forecast_minutes_ignore: Optional[int] = 0,
):
"""Abtstract base class for PVNet submodels.
@@ -361,7 +371,6 @@ def __init__(
target_key: The key of the target variable in the batch
interval_minutes: The interval in minutes between each timestep in the data
timestep_intervals_to_plot: Intervals, in timesteps, to plot during training
use_weighted_loss: Whether to use a weighted loss function
forecast_minutes_ignore: Number of forecast minutes to ignore when calculating losses.
For example, if set to 60, the model doesn't predict the first 60 minutes
"""
@@ -393,23 +402,24 @@ def __init__(
self.forecast_len = (forecast_minutes - forecast_minutes_ignore) // interval_minutes
self.forecast_len_ignore = forecast_minutes_ignore // interval_minutes

self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len)

self._accumulated_metrics = MetricAccumulator()
self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key_name)
self._accumulated_y_hat = PredAccumulator()
self._horizon_maes = MetricAccumulator()

# Store whether the model should use quantile regression or simply predict the mean
self.use_quantile_regression = self.output_quantiles is not None
self.use_weighted_loss = use_weighted_loss

# Store the number of output features that the model should predict for
if self.use_quantile_regression:
self.num_output_features = self.forecast_len * len(self.output_quantiles)
else:
self.num_output_features = self.forecast_len


def transfer_batch_to_device(self, batch, device, dataloader_idx):
"""Method to move custom batches to a given device"""
return copy_batch_to_device(batch, device)

def _quantiles_to_prediction(self, y_quantiles):
"""
Convert network prediction into a point prediction.
@@ -451,13 +461,11 @@ def _calculate_quantile_loss(self, y_quantiles, y):
errors = y - y_quantiles[..., i]
losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
losses = 2 * torch.cat(losses, dim=2)
if self.use_weighted_loss:
weights = self.weighted_losses.weights.unsqueeze(1).unsqueeze(0).to(y.device)
losses = losses * weights

return losses.mean()
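# Worked toy example (not from this PR) of the pinball term above, max((q - 1) * e, q * e), with q = 0.9:
#   under-prediction, e = y - y_quantile =  2.0:  max(-0.1 *  2.0, 0.9 *  2.0) = 1.8
#   over-prediction,  e = y - y_quantile = -2.0:  max(-0.1 * -2.0, 0.9 * -2.0) = 0.2
# so the 0.9-quantile forecast is penalised far more for sitting below the target than above it;
# the factor of 2 applied before the mean is just a constant rescaling.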

def _calculate_common_losses(self, y, y_hat):
"""Calculate losses common to train, test, and val"""
"""Calculate losses common to train, and val"""

losses = {}

@@ -469,19 +477,13 @@ def _calculate_common_losses(self, y, y_hat):
mse_loss = F.mse_loss(y_hat, y)
mae_loss = F.l1_loss(y_hat, y)

# calculate mse, mae with exp weighted loss
mse_exp = self.weighted_losses.get_mse_exp(output=y_hat, target=y)
mae_exp = self.weighted_losses.get_mae_exp(output=y_hat, target=y)

# TODO: Compute correlation coef using np.corrcoef(tensor with
# shape (2, num_timesteps))[0, 1] on each example, and taking
# the mean across the batch?
losses.update(
{
"MSE": mse_loss,
"MAE": mae_loss,
"MSE_EXP": mse_exp,
"MAE_EXP": mae_exp,
}
)

@@ -527,12 +529,6 @@ def _calculate_val_losses(self, y, y_hat):
losses.update(self._step_mae_and_mse(y, y_persist, dict_key_root="persistence"))
return losses

def _calculate_test_losses(self, y, y_hat):
"""Calculate additional test losses"""
# No additional test losses
losses = {}
return losses

def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
"""Internal function to accumulate training batches and log results.

@@ -578,7 +574,7 @@ def _training_accumulate_log(self, batch, batch_idx, losses, y_hat):
def training_step(self, batch, batch_idx):
"""Run training step"""
y_hat = self(batch)
y = batch[self._target_key][:, -self.forecast_len :, 0]
y = batch[self._target_key][:, -self.forecast_len :]

losses = self._calculate_common_losses(y, y_hat)
losses = {f"{k}/train": v for k, v in losses.items()}
@@ -612,8 +608,8 @@ def _log_forecast_plot(self, batch, y_hat, accum_batch_num, timesteps_to_plot, p
def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""
y_hat = self(batch)
# Sensor seems to be in batch, station, time order
y = batch[self._target_key][:, -self.forecast_len :, 0]

y = batch[self._target_key][:, -self.forecast_len :]

# Expand persistence to be the same shape as y
losses = self._calculate_common_losses(y, y_hat)
@@ -693,37 +689,6 @@ def on_validation_epoch_end(self):
print("Failed to log horizon_loss_curve to wandb")
print(e)

def test_step(self, batch, batch_idx):
"""Run test step"""
y_hat = self(batch)
y = batch[self._target_key][:, -self.forecast_len :, 0]

losses = self._calculate_common_losses(y, y_hat)
losses.update(self._calculate_val_losses(y, y_hat))
losses.update(self._calculate_test_losses(y, y_hat))
logged_losses = {f"{k}/test": v for k, v in losses.items()}

self.log_dict(
logged_losses,
on_step=False,
on_epoch=True,
)

if self.use_quantile_regression:
y_hat = self._quantiles_to_prediction(y_hat)

return construct_ocf_ml_metrics_batch_df(batch, y, y_hat)

def on_test_epoch_end(self, outputs):
"""Evalauate test results using oc_ml_metrics"""
results_df = pd.concat(outputs)
# setting model_name="test" gives us keys like "test/mw/forecast_horizon_30_minutes/mae"
metrics = evaluation(results_df=results_df, model_name="test", outturn_unit="mw")

self.log_dict(
metrics,
)

def configure_optimizers(self):
"""Configure the optimizers using learning rate found with LR finder if used"""
if self.lr is not None: