|
1 | 1 | """Utils common between Wind and PV datamodules""" |
2 | | -import numpy as np |
3 | | -import torch |
4 | 2 | from ocf_datapipes.batch import BatchKey, unstack_np_batch_into_examples |
5 | 3 | from torch.utils.data import IterDataPipe, functional_datapipe |
6 | 4 |
|
7 | 5 |
|
8 | | -def copy_batch_to_device(batch, device): |
9 | | - """Moves a dict-batch of tensors to new device.""" |
10 | | - batch_copy = {} |
11 | | - |
12 | | - for k, v in batch.items(): |
13 | | - if isinstance(v, dict): |
14 | | - # Recursion to reach the nested NWP |
15 | | - batch_copy[k] = copy_batch_to_device(v, device) |
16 | | - elif isinstance(v, torch.Tensor): |
17 | | - batch_copy[k] = v.to(device) |
18 | | - else: |
19 | | - batch_copy[k] = v |
20 | | - return batch_copy |
21 | | - |
22 | | - |
23 | | -def batch_to_tensor(batch): |
24 | | - """Moves numpy batch to a tensor""" |
25 | | - for k, v in batch.items(): |
26 | | - if isinstance(v, dict): |
27 | | - # Recursion to reach the nested NWP |
28 | | - batch[k] = batch_to_tensor(v) |
29 | | - elif isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number): |
30 | | - batch[k] = torch.as_tensor(v) |
31 | | - return batch |
32 | | - |
33 | | - |
34 | 6 | @functional_datapipe("split_batches") |
35 | 7 | class BatchSplitter(IterDataPipe): |
36 | 8 | """Pipeline step to split batches of data and yield single examples""" |
|
0 commit comments