From 7c1b8fc549fef4bbcf6fc22e720b555301768514 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 16:01:08 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet/data/datamodule.py | 15 +++++++++++---- pvnet/data/pv_site_datamodule.py | 10 ++++++++-- pvnet/data/wind_datamodule.py | 11 ++++++++--- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index be579a0f..4e2fb340 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -5,8 +5,8 @@ from ocf_datapipes.training.pvnet import pvnet_datapipe from torch.utils.data.datapipes.iter import FileLister -from pvnet.data.utils import batch_to_tensor from pvnet.data.base import BaseDataModule +from pvnet.data.utils import batch_to_tensor class DataModule(BaseDataModule): @@ -40,8 +40,16 @@ def __init__( `configuration` or 'train/val/test_period'. """ - super().__init__(configuration=configuration, batch_size=batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, train_period=train_period, val_period=val_period, test_period=test_period, batch_dir=batch_dir) - + super().__init__( + configuration=configuration, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + train_period=train_period, + val_period=val_period, + test_period=test_period, + batch_dir=batch_dir, + ) def _get_datapipe(self, start_time, end_time): data_pipeline = pvnet_datapipe( @@ -82,4 +90,3 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False): ) return data_pipeline - diff --git a/pvnet/data/pv_site_datamodule.py b/pvnet/data/pv_site_datamodule.py index 52669a53..4de606de 100644 --- a/pvnet/data/pv_site_datamodule.py +++ b/pvnet/data/pv_site_datamodule.py @@ -3,8 +3,8 @@ from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch from ocf_datapipes.training.pvnet_site import pvnet_site_netcdf_datapipe -from pvnet.data.base import BaseDataModule +from pvnet.data.base import BaseDataModule from pvnet.data.utils import batch_to_tensor @@ -33,7 +33,13 @@ def __init__( 'train/val/test_period'. """ - super().__init__(configuration=configuration, batch_size=batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, batch_dir=batch_dir) + super().__init__( + configuration=configuration, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + batch_dir=batch_dir, + ) def _get_datapipe(self, start_time, end_time): data_pipeline = pvnet_site_netcdf_datapipe( diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py index f6786e13..8b3988fc 100644 --- a/pvnet/data/wind_datamodule.py +++ b/pvnet/data/wind_datamodule.py @@ -3,8 +3,8 @@ from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch from ocf_datapipes.training.windnet import windnet_netcdf_datapipe -from pvnet.data.base import BaseDataModule +from pvnet.data.base import BaseDataModule from pvnet.data.utils import batch_to_tensor @@ -33,8 +33,13 @@ def __init__( 'train/val/test_period'. """ - super().__init__(configuration=configuration, batch_size=batch_size, num_workers=num_workers, - prefetch_factor=prefetch_factor, batch_dir=batch_dir) + super().__init__( + configuration=configuration, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=prefetch_factor, + batch_dir=batch_dir, + ) def _get_datapipe(self, start_time, end_time): data_pipeline = windnet_netcdf_datapipe(