From 1ad649dd9f552ce11a364168e29bb74247f06220 Mon Sep 17 00:00:00 2001 From: Jacob Bieker Date: Tue, 6 Feb 2024 09:52:17 +0000 Subject: [PATCH] Linting fixes --- pvnet/data/pv_site_datamodule.py | 3 +++ pvnet/data/wind_datamodule.py | 3 +++ pvnet/models/base_model.py | 1 + pvnet/models/multimodal/multimodal.py | 2 ++ 4 files changed, 9 insertions(+) diff --git a/pvnet/data/pv_site_datamodule.py b/pvnet/data/pv_site_datamodule.py index 7fa6086f..7d49f8bd 100644 --- a/pvnet/data/pv_site_datamodule.py +++ b/pvnet/data/pv_site_datamodule.py @@ -31,6 +31,9 @@ def __init__( configuration: Path to datapipe configuration file. batch_size: Batch size. num_workers: Number of workers to use in multiprocess batch loading. + train_period: Date range filter for train dataloader. + val_period: Date range filter for val dataloader. + test_period: Date range filter for test dataloader. prefetch_factor: Number of data will be prefetched at the end of each worker process. batch_dir: Path to the directory of pre-saved batches. Cannot be used together with 'train/val/test_period'. diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py index 43d6a34b..eec3365b 100644 --- a/pvnet/data/wind_datamodule.py +++ b/pvnet/data/wind_datamodule.py @@ -30,6 +30,9 @@ def __init__( Args: configuration: Path to datapipe configuration file. batch_size: Batch size. + train_period: Date range filter for train dataloader. + val_period: Date range filter for val dataloader. + test_period: Date range filter for test dataloader. num_workers: Number of workers to use in multiprocess batch loading. prefetch_factor: Number of data will be prefetched at the end of each worker process. batch_dir: Path to the directory of pre-saved batches. Cannot be used together with diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py index 33d37127..f61141c3 100644 --- a/pvnet/models/base_model.py +++ b/pvnet/models/base_model.py @@ -253,6 +253,7 @@ def __init__( None the output is a single value. target_key: The key of the target variable in the batch interval_minutes: The interval in minutes between each timestep in the data + timestep_intervals_to_plot: Intervals, in timesteps, to plot during training """ super().__init__() diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py index 314745d1..ea0543d5 100644 --- a/pvnet/models/multimodal/multimodal.py +++ b/pvnet/models/multimodal/multimodal.py @@ -116,6 +116,8 @@ def __init__( sensor_interval_minutes: The interval between each sample of the sensor data image_embedding_dim: The number of dimensions to use for the image embedding timestep_intervals_to_plot: Intervals, in timesteps, to plot in addition to the full forecast + sensor_encoder: Encoder for sensor data + sensor_history_minutes: Length of recent sensor data used as input. """ self.include_gsp_yield_history = include_gsp_yield_history