Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 21, 2025
1 parent a85adca commit 83fa44a
Show file tree
Hide file tree
Showing 14 changed files with 22 additions and 24 deletions.
5 changes: 4 additions & 1 deletion pvnet/data/site_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from glob import glob

import xarray as xr
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_netcdf_to_numpy_sample
from ocf_data_sampler.torch_datasets.datasets.site import (
SitesDataset,
convert_netcdf_to_numpy_sample,
)
from torch.utils.data import Dataset

from pvnet.data.base_datamodule import BaseDataModule
Expand Down
3 changes: 2 additions & 1 deletion pvnet/data/uk_regional_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from glob import glob

import torch
from pvnet.data.base_datamodule import BaseDataModule
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
from torch.utils.data import Dataset

from pvnet.data.base_datamodule import BaseDataModule


class NumpybatchPremadeSamplesDataset(Dataset):
"""Dataset to load NumpyBatch samples"""
Expand Down
2 changes: 1 addition & 1 deletion pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
from huggingface_hub.file_download import hf_hub_download
from huggingface_hub.hf_api import HfApi
from ocf_datapipes.batch import BatchKey, copy_batch_to_device
from ocf_datapipes.batch import copy_batch_to_device

from pvnet.models.utils import (
BatchAccumulator,
Expand Down
1 change: 0 additions & 1 deletion pvnet/models/baseline/last_value.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Persistence model"""

from ocf_datapipes.batch import BatchKey

import pvnet
from pvnet.models.base_model import BaseModel
Expand Down
1 change: 0 additions & 1 deletion pvnet/models/baseline/single_value.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Average value model"""
import torch
from ocf_datapipes.batch import BatchKey
from torch import nn

import pvnet
Expand Down
1 change: 0 additions & 1 deletion pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Optional

import torch
from ocf_datapipes.batch import BatchKey, NWPBatchKey
from omegaconf import DictConfig
from torch import nn

Expand Down
1 change: 0 additions & 1 deletion pvnet/models/multimodal/multimodal_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Base model class for multimodal model and unimodal teacher"""
from ocf_datapipes.batch import BatchKey, NWPBatchKey
from torchvision.transforms.functional import center_crop

from pvnet.models.base_model import BaseModel
Expand Down
4 changes: 2 additions & 2 deletions pvnet/models/multimodal/site_encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def _encode_inputs(self, x):
# Shape: [batch size, sequence length, number of sites] -> [8, 197, 1]
# Shape: [batch size, station_id, sequence length, channels] -> [8, 197, 26, 23]
input_data = x[f"{self.input_key_to_use}"]
if len(input_data.shape) == 2: # one site per sample
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
if len(input_data.shape) == 2: # one site per sample
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
if len(input_data.shape) == 4: # Has multiple channels
input_data = input_data[:, :, : self.sequence_length]
input_data = einops.rearrange(input_data, "b id s c -> b (s c) id")
Expand Down
1 change: 0 additions & 1 deletion pvnet/models/multimodal/unimodal_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import hydra
import torch
import torch.nn.functional as F
from ocf_datapipes.batch import BatchKey, NWPBatchKey
from pyaml_env import parse_config
from torch import nn

Expand Down
1 change: 0 additions & 1 deletion pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import xarray as xr
from lightning.pytorch.loggers import Logger
from lightning.pytorch.utilities import rank_zero_only
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.utils import Location
from omegaconf import DictConfig, OmegaConf

Expand Down
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def sample_train_val_datamodule():

file_n = 0

for file_n, file in enumerate(glob.glob("tests/test_data/presaved_samples_uk_regional/train/*.pt")):
for file_n, file in enumerate(
glob.glob("tests/test_data/presaved_samples_uk_regional/train/*.pt")
):
sample = torch.load(file)

for i in range(n_duplicates):
Expand Down Expand Up @@ -204,14 +206,12 @@ def site_encoder_model_kwargs():
)
return kwargs


@pytest.fixture()
def site_encoder_model_kwargs_dsampler():
# Used to test site encoder model on PV data
kwargs = dict(
sequence_length=60 // 15 - 1,
num_sites=1,
out_features=128,
target_key_to_use="site"
sequence_length=60 // 15 - 1, num_sites=1, out_features=128, target_key_to_use="site"
)
return kwargs

Expand Down
4 changes: 3 additions & 1 deletion tests/models/multimodal/site_encoders/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,6 @@ def test_simplelearnedaggregator_backward(sample_pv_batch, site_encoder_model_kw


def test_singleattentionnetwork_backward(sample_site_batch, site_encoder_model_kwargs_dsampler):
_test_model_backward(sample_site_batch, SingleAttentionNetwork, site_encoder_model_kwargs_dsampler)
_test_model_backward(
sample_site_batch, SingleAttentionNetwork, site_encoder_model_kwargs_dsampler
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ general:
name: windnet_data_sampler

input_data:

site:
time_resolution_minutes: 15
interval_start_minutes: -60
Expand All @@ -13,7 +12,6 @@ input_data:
dropout_timedeltas_minutes: null
dropout_fraction: 0 # Fraction of samples with dropout


# nwp:
# mo_global:
# provider: mo_global
Expand All @@ -35,7 +33,7 @@ input_data:
# # How long after the NWP init-time are we still willing to use this forecast
# # If null we use each init-time for all steps it covers
# max_staleness_minutes: null

nwp:
ecmwf:
provider: ecmwf
Expand Down
8 changes: 4 additions & 4 deletions tests/test_data/presaved_samples_site/datamodule.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ sample_output_dir: /mnt/storage_ssd_4tb/site_batches_d_sampler_output_multiple_s
num_val_samples: 64
num_train_samples: 128
train_period:
- '2022-03-01'
- '2022-06-01'
- "2022-03-01"
- "2022-06-01"
val_period:
- '2022-06-02'
- '2022-07-01'
- "2022-06-02"
- "2022-07-01"

0 comments on commit 83fa44a

Please sign in to comment.