Skip to content

Commit

Permalink
copy_batch_to_device functionality (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
alirashidAR authored Feb 18, 2025
1 parent 6dbd5fa commit 0829f92
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
23 changes: 23 additions & 0 deletions ocf_data_sampler/sample/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,26 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
elif np.issubdtype(v.dtype, np.number):
batch[k] = torch.as_tensor(v)
return batch


def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
"""
Moves tensor leaves in a nested dict to a new device.
Args:
batch: Nested dict with tensors to move.
device: Device to move tensors to.
Returns:
A dict with tensors moved to the new device.
"""
batch_copy = {}

for k, v in batch.items():
if isinstance(v, dict):
batch_copy[k] = copy_batch_to_device(v, device)
elif isinstance(v, torch.Tensor):
batch_copy[k] = v.to(device)
else:
batch_copy[k] = v
return batch_copy
2 changes: 1 addition & 1 deletion scripts/refactor_site.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
""" Helper functions for refactoring legacy site data """

import xarray as xr

def legacy_format(data_ds, metadata_df):
"""This formats old legacy data to the new format.
Expand Down
19 changes: 18 additions & 1 deletion tests/test_sample/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from pathlib import Path
from ocf_data_sampler.sample.base import (
SampleBase,
batch_to_tensor
batch_to_tensor,
copy_batch_to_device
)

class TestSample(SampleBase):
Expand Down Expand Up @@ -145,3 +146,19 @@ def test_batch_to_tensor_multidimensional():
assert tensor_batch['matrix'].shape == (2, 2)
assert tensor_batch['tensor'].shape == (2, 2, 2)
assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]]))


def test_copy_batch_to_device():
""" Test moving tensors to a different device """
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {
'tensor_data': torch.tensor([1, 2, 3]),
'nested': {
'matrix': torch.tensor([[1, 2], [3, 4]])
},
'non_tensor': 'unchanged'
}
moved_batch = copy_batch_to_device(batch, device)
assert moved_batch['tensor_data'].device == device
assert moved_batch['nested']['matrix'].device == device
assert moved_batch['non_tensor'] == 'unchanged' # Non-tensors should remain unchanged

0 comments on commit 0829f92

Please sign in to comment.