|
1 | 1 | """Utils common between Wind and PV datamodules"""
|
2 |
| -import numpy as np |
3 |
| -import torch |
4 | 2 | from ocf_datapipes.batch import BatchKey, unstack_np_batch_into_examples
|
5 | 3 | from torch.utils.data import IterDataPipe, functional_datapipe
|
6 | 4 |
|
7 | 5 |
|
8 |
| -def copy_batch_to_device(batch, device): |
9 |
| - """Moves a dict-batch of tensors to new device.""" |
10 |
| - batch_copy = {} |
11 |
| - |
12 |
| - for k, v in batch.items(): |
13 |
| - if isinstance(v, dict): |
14 |
| - # Recursion to reach the nested NWP |
15 |
| - batch_copy[k] = copy_batch_to_device(v, device) |
16 |
| - elif isinstance(v, torch.Tensor): |
17 |
| - batch_copy[k] = v.to(device) |
18 |
| - else: |
19 |
| - batch_copy[k] = v |
20 |
| - return batch_copy |
21 |
| - |
22 |
| - |
23 |
| -def batch_to_tensor(batch): |
24 |
| - """Moves numpy batch to a tensor""" |
25 |
| - for k, v in batch.items(): |
26 |
| - if isinstance(v, dict): |
27 |
| - # Recursion to reach the nested NWP |
28 |
| - batch[k] = batch_to_tensor(v) |
29 |
| - elif isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number): |
30 |
| - batch[k] = torch.as_tensor(v) |
31 |
| - return batch |
32 |
| - |
33 |
| - |
34 | 6 | @functional_datapipe("split_batches")
|
35 | 7 | class BatchSplitter(IterDataPipe):
|
36 | 8 | """Pipeline step to split batches of data and yield single examples"""
|
|
0 commit comments