From 27185a9b13f48813d57b219cd2be515d498397bc Mon Sep 17 00:00:00 2001 From: alirashidAR Date: Sun, 16 Feb 2025 01:35:35 +0530 Subject: [PATCH 1/3] feat: copy_batch__to_device function added with test --- ocf_data_sampler/sample/base.py | 24 ++++++++++++++++++++++++ tests/test_sample/test_base.py | 19 ++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/ocf_data_sampler/sample/base.py b/ocf_data_sampler/sample/base.py index ecbcdf44..491827ec 100644 --- a/ocf_data_sampler/sample/base.py +++ b/ocf_data_sampler/sample/base.py @@ -73,3 +73,27 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch: elif np.issubdtype(v.dtype, np.number): batch[k] = torch.as_tensor(v) return batch + +import torch + +def copy_batch_to_device(batch: dict, device: torch.device) -> dict: + """ + Moves tensor leaves in a nested dict to a new device. + + Args: + batch: Nested dict with tensors to move. + device: Device to move tensors to. + + Returns: + A dict with tensors moved to the new device. + """ + batch_copy = {} + + for k, v in batch.items(): + if isinstance(v, dict): + batch_copy[k] = copy_batch_to_device(v, device) + elif isinstance(v, torch.Tensor): + batch_copy[k] = v.to(device) + else: + batch_copy[k] = v + return batch_copy diff --git a/tests/test_sample/test_base.py b/tests/test_sample/test_base.py index b7618ce1..6486929f 100644 --- a/tests/test_sample/test_base.py +++ b/tests/test_sample/test_base.py @@ -9,7 +9,8 @@ from pathlib import Path from ocf_data_sampler.sample.base import ( SampleBase, - batch_to_tensor + batch_to_tensor, + copy_batch_to_device ) class TestSample(SampleBase): @@ -145,3 +146,19 @@ def test_batch_to_tensor_multidimensional(): assert tensor_batch['matrix'].shape == (2, 2) assert tensor_batch['tensor'].shape == (2, 2, 2) assert torch.equal(tensor_batch['matrix'], torch.tensor([[1, 2], [3, 4]])) + + +def test_copy_batch_to_device(): + """ Test moving tensors to a different device """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + batch = { + 'tensor_data': torch.tensor([1, 2, 3]), + 'nested': { + 'matrix': torch.tensor([[1, 2], [3, 4]]) + }, + 'non_tensor': 'unchanged' + } + moved_batch = copy_batch_to_device(batch, device) + assert moved_batch['tensor_data'].device == device + assert moved_batch['nested']['matrix'].device == device + assert moved_batch['non_tensor'] == 'unchanged' # Non-tensors should remain unchanged \ No newline at end of file From 87133768ffadc9832185b91bfe499cb761318a6a Mon Sep 17 00:00:00 2001 From: alirashidAR Date: Sun, 16 Feb 2025 01:42:32 +0530 Subject: [PATCH 2/3] fix: xarray import --- scripts/refactor_site.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/refactor_site.py b/scripts/refactor_site.py index b3bf5ea9..15abb24b 100644 --- a/scripts/refactor_site.py +++ b/scripts/refactor_site.py @@ -1,5 +1,5 @@ """ Helper functions for refactoring legacy site data """ - +import xarray as xr def legacy_format(data_ds, metadata_df): """This formats old legacy data to the new format. From 7036d0660677c5bcd3dad1da8ca51ce4093e03e0 Mon Sep 17 00:00:00 2001 From: alirashidAR Date: Sun, 16 Feb 2025 01:50:44 +0530 Subject: [PATCH 3/3] fix: redundant import --- ocf_data_sampler/sample/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ocf_data_sampler/sample/base.py b/ocf_data_sampler/sample/base.py index 491827ec..c5b44a37 100644 --- a/ocf_data_sampler/sample/base.py +++ b/ocf_data_sampler/sample/base.py @@ -74,7 +74,6 @@ def batch_to_tensor(batch: NumpyBatch) -> TensorBatch: batch[k] = torch.as_tensor(v) return batch -import torch def copy_batch_to_device(batch: dict, device: torch.device) -> dict: """