Skip to content

Commit d5fc458

Browse files
authored
Merge pull request #170 from markus-kreft/iss111
Use batch transformation functions from ocf_datapipes
2 parents 6d6dba7 + b255253 commit d5fc458

7 files changed

+6
-39
lines changed

pvnet/data/datamodule.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,11 @@
33
import resource
44

55
import torch
6-
from ocf_datapipes.batch import stack_np_examples_into_batch
6+
from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch
77
from ocf_datapipes.training.pvnet import pvnet_datapipe
88
from torch.utils.data.datapipes.iter import FileLister
99

1010
from pvnet.data.base import BaseDataModule
11-
from pvnet.data.utils import batch_to_tensor
1211

1312
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
1413
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, rlimit[1]))

pvnet/data/pv_site_datamodule.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
""" Data module for pytorch lightning """
22
import glob
33

4-
from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch
4+
from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
55
from ocf_datapipes.training.pvnet_site import pvnet_site_netcdf_datapipe
66

77
from pvnet.data.base import BaseDataModule
8-
from pvnet.data.utils import batch_to_tensor
98

109

1110
class PVSiteDataModule(BaseDataModule):

pvnet/data/utils.py

-28
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,8 @@
11
"""Utils common between Wind and PV datamodules"""
2-
import numpy as np
3-
import torch
42
from ocf_datapipes.batch import BatchKey, unstack_np_batch_into_examples
53
from torch.utils.data import IterDataPipe, functional_datapipe
64

75

8-
def copy_batch_to_device(batch, device):
9-
"""Moves a dict-batch of tensors to new device."""
10-
batch_copy = {}
11-
12-
for k, v in batch.items():
13-
if isinstance(v, dict):
14-
# Recursion to reach the nested NWP
15-
batch_copy[k] = copy_batch_to_device(v, device)
16-
elif isinstance(v, torch.Tensor):
17-
batch_copy[k] = v.to(device)
18-
else:
19-
batch_copy[k] = v
20-
return batch_copy
21-
22-
23-
def batch_to_tensor(batch):
24-
"""Moves numpy batch to a tensor"""
25-
for k, v in batch.items():
26-
if isinstance(v, dict):
27-
# Recursion to reach the nested NWP
28-
batch[k] = batch_to_tensor(v)
29-
elif isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
30-
batch[k] = torch.as_tensor(v)
31-
return batch
32-
33-
346
@functional_datapipe("split_batches")
357
class BatchSplitter(IterDataPipe):
368
"""Pipeline step to split batches of data and yield single examples"""

pvnet/data/wind_datamodule.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
""" Data module for pytorch lightning """
22
import glob
33

4-
from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch
4+
from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
55
from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
66

77
from pvnet.data.base import BaseDataModule
8-
from pvnet.data.utils import batch_to_tensor
98

109

1110
class WindDataModule(BaseDataModule):

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
ocf_datapipes>=3.1.5
1+
ocf_datapipes>=3.3.6
22
ocf_ml_metrics>=0.0.11
33
numpy
44
pandas

scripts/save_batches.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
import hydra
3535
import torch
36-
from ocf_datapipes.batch import stack_np_examples_into_batch
36+
from ocf_datapipes.batch import batch_to_tensor, stack_np_examples_into_batch
3737
from ocf_datapipes.training.pvnet import pvnet_datapipe
3838
from ocf_datapipes.training.pvnet_site import pvnet_site_datapipe
3939
from ocf_datapipes.training.windnet import windnet_datapipe
@@ -43,7 +43,6 @@
4343
from torch.utils.data.datapipes.iter import IterableWrapper
4444
from tqdm import tqdm
4545

46-
from pvnet.data.utils import batch_to_tensor
4746
from pvnet.utils import print_config
4847

4948
warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

scripts/save_concurrent_batches.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import hydra
2424
import numpy as np
2525
import torch
26-
from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch
26+
from ocf_datapipes.batch import BatchKey, batch_to_tensor, stack_np_examples_into_batch
2727
from ocf_datapipes.training.common import (
2828
open_and_return_datapipes,
2929
)
@@ -34,7 +34,6 @@
3434
from torch.utils.data.datapipes.iter import IterableWrapper
3535
from tqdm import tqdm
3636

37-
from pvnet.data.utils import batch_to_tensor
3837
from pvnet.utils import GSPLocationLookup
3938

4039
warnings.filterwarnings("ignore", category=sa_exc.SAWarning)

0 commit comments

Comments
 (0)