Skip to content

Commit 83fa44a

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a85adca commit 83fa44a

File tree

14 files changed

+22
-24
lines changed

14 files changed

+22
-24
lines changed

pvnet/data/site_datamodule.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
from glob import glob
33

44
import xarray as xr
5-
from ocf_data_sampler.torch_datasets.datasets.site import SitesDataset, convert_netcdf_to_numpy_sample
5+
from ocf_data_sampler.torch_datasets.datasets.site import (
6+
SitesDataset,
7+
convert_netcdf_to_numpy_sample,
8+
)
69
from torch.utils.data import Dataset
710

811
from pvnet.data.base_datamodule import BaseDataModule

pvnet/data/uk_regional_datamodule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
from glob import glob
33

44
import torch
5-
from pvnet.data.base_datamodule import BaseDataModule
65
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk_regional import PVNetUKRegionalDataset
76
from torch.utils.data import Dataset
87

8+
from pvnet.data.base_datamodule import BaseDataModule
9+
910

1011
class NumpybatchPremadeSamplesDataset(Dataset):
1112
"""Dataset to load NumpyBatch samples"""

pvnet/models/base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from huggingface_hub.constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME
1919
from huggingface_hub.file_download import hf_hub_download
2020
from huggingface_hub.hf_api import HfApi
21-
from ocf_datapipes.batch import BatchKey, copy_batch_to_device
21+
from ocf_datapipes.batch import copy_batch_to_device
2222

2323
from pvnet.models.utils import (
2424
BatchAccumulator,

pvnet/models/baseline/last_value.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Persistence model"""
22

3-
from ocf_datapipes.batch import BatchKey
43

54
import pvnet
65
from pvnet.models.base_model import BaseModel

pvnet/models/baseline/single_value.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Average value model"""
22
import torch
3-
from ocf_datapipes.batch import BatchKey
43
from torch import nn
54

65
import pvnet

pvnet/models/multimodal/multimodal.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Optional
55

66
import torch
7-
from ocf_datapipes.batch import BatchKey, NWPBatchKey
87
from omegaconf import DictConfig
98
from torch import nn
109

pvnet/models/multimodal/multimodal_base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Base model class for multimodal model and unimodal teacher"""
2-
from ocf_datapipes.batch import BatchKey, NWPBatchKey
32
from torchvision.transforms.functional import center_crop
43

54
from pvnet.models.base_model import BaseModel

pvnet/models/multimodal/site_encoders/encoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def _encode_inputs(self, x):
209209
# Shape: [batch size, sequence length, number of sites] -> [8, 197, 1]
210210
# Shape: [batch size, station_id, sequence length, channels] -> [8, 197, 26, 23]
211211
input_data = x[f"{self.input_key_to_use}"]
212-
if len(input_data.shape) == 2: # one site per sample
213-
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
212+
if len(input_data.shape) == 2: # one site per sample
213+
input_data = input_data.unsqueeze(-1) # add dimension of 1 to end to make 3D
214214
if len(input_data.shape) == 4: # Has multiple channels
215215
input_data = input_data[:, :, : self.sequence_length]
216216
input_data = einops.rearrange(input_data, "b id s c -> b (s c) id")

pvnet/models/multimodal/unimodal_teacher.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import hydra
88
import torch
99
import torch.nn.functional as F
10-
from ocf_datapipes.batch import BatchKey, NWPBatchKey
1110
from pyaml_env import parse_config
1211
from torch import nn
1312

pvnet/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import xarray as xr
1414
from lightning.pytorch.loggers import Logger
1515
from lightning.pytorch.utilities import rank_zero_only
16-
from ocf_datapipes.batch import BatchKey
1716
from ocf_datapipes.utils import Location
1817
from omegaconf import DictConfig, OmegaConf
1918

tests/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ def sample_train_val_datamodule():
105105

106106
file_n = 0
107107

108-
for file_n, file in enumerate(glob.glob("tests/test_data/presaved_samples_uk_regional/train/*.pt")):
108+
for file_n, file in enumerate(
109+
glob.glob("tests/test_data/presaved_samples_uk_regional/train/*.pt")
110+
):
109111
sample = torch.load(file)
110112

111113
for i in range(n_duplicates):
@@ -204,14 +206,12 @@ def site_encoder_model_kwargs():
204206
)
205207
return kwargs
206208

209+
207210
@pytest.fixture()
208211
def site_encoder_model_kwargs_dsampler():
209212
# Used to test site encoder model on PV data
210213
kwargs = dict(
211-
sequence_length=60 // 15 - 1,
212-
num_sites=1,
213-
out_features=128,
214-
target_key_to_use="site"
214+
sequence_length=60 // 15 - 1, num_sites=1, out_features=128, target_key_to_use="site"
215215
)
216216
return kwargs
217217

tests/models/multimodal/site_encoders/test_encoders.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,6 @@ def test_simplelearnedaggregator_backward(sample_pv_batch, site_encoder_model_kw
5858

5959

6060
def test_singleattentionnetwork_backward(sample_site_batch, site_encoder_model_kwargs_dsampler):
61-
_test_model_backward(sample_site_batch, SingleAttentionNetwork, site_encoder_model_kwargs_dsampler)
61+
_test_model_backward(
62+
sample_site_batch, SingleAttentionNetwork, site_encoder_model_kwargs_dsampler
63+
)

tests/test_data/presaved_samples_site/data_configuration.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ general:
33
name: windnet_data_sampler
44

55
input_data:
6-
76
site:
87
time_resolution_minutes: 15
98
interval_start_minutes: -60
@@ -13,7 +12,6 @@ input_data:
1312
dropout_timedeltas_minutes: null
1413
dropout_fraction: 0 # Fraction of samples with dropout
1514

16-
1715
# nwp:
1816
# mo_global:
1917
# provider: mo_global
@@ -35,7 +33,7 @@ input_data:
3533
# # How long after the NWP init-time are we still willing to use this forecast
3634
# # If null we use each init-time for all steps it covers
3735
# max_staleness_minutes: null
38-
36+
3937
nwp:
4038
ecmwf:
4139
provider: ecmwf

tests/test_data/presaved_samples_site/datamodule.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ sample_output_dir: /mnt/storage_ssd_4tb/site_batches_d_sampler_output_multiple_s
77
num_val_samples: 64
88
num_train_samples: 128
99
train_period:
10-
- '2022-03-01'
11-
- '2022-06-01'
10+
- "2022-03-01"
11+
- "2022-06-01"
1212
val_period:
13-
- '2022-06-02'
14-
- '2022-07-01'
13+
- "2022-06-02"
14+
- "2022-07-01"

0 commit comments

Comments
 (0)