diff --git a/monai/apps/reconstruction/transforms/dictionary.py b/monai/apps/reconstruction/transforms/dictionary.py index c166740768..518049f8e0 100644 --- a/monai/apps/reconstruction/transforms/dictionary.py +++ b/monai/apps/reconstruction/transforms/dictionary.py @@ -12,6 +12,7 @@ from __future__ import annotations from collections.abc import Hashable, Mapping, Sequence +from typing import Any import numpy as np from numpy import ndarray @@ -20,6 +21,7 @@ from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor +from monai.data import MetaTensor from monai.transforms import InvertibleTransform from monai.transforms.croppad.array import SpatialCrop from monai.transforms.intensity.array import NormalizeIntensity @@ -57,15 +59,26 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T Returns: the new data dictionary """ + d = dict(data) + + meta: dict[str, Any] + if isinstance(d[self.meta_key], MetaTensor): + # meta tensor + meta = d[self.meta_key].meta # type: ignore + else: + # meta dict + meta = d[self.meta_key] + for key in self.keys: - if key in d[self.meta_key]: - d[key] = d[self.meta_key][key] # type: ignore + if key in meta: + d[key] = meta[key] # type: ignore elif not self.allow_missing_keys: raise KeyError( f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data" " and allow_missing_keys==False." ) + return d # type: ignore diff --git a/tests/apps/reconstruction/transforms/test_extract_data_key_from_meta_keyd.py b/tests/apps/reconstruction/transforms/test_extract_data_key_from_meta_keyd.py new file mode 100644 index 0000000000..56d62bab3f --- /dev/null +++ b/tests/apps/reconstruction/transforms/test_extract_data_key_from_meta_keyd.py @@ -0,0 +1,39 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from monai.apps.reconstruction.transforms.dictionary import ExtractDataKeyFromMetaKeyd +from monai.data import MetaTensor + + +class TestExtractDataKeyFromMetaKeyd(unittest.TestCase): + def test_extract_data_key_from_dic(self): + data = {"image_data": MetaTensor([1, 2, 3]), "foo_meta_dict": {"filename_or_obj": "test_image.nii.gz"}} + + extract = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="foo_meta_dict") + result = extract(data) + + assert data["foo_meta_dict"]["filename_or_obj"] == result["filename_or_obj"] + + def test_extract_data_key_from_meta_tensor(self): + data = {"image_data": MetaTensor([1, 2, 3], meta={"filename_or_obj": 1})} + + extract = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image_data") + result = extract(data) + + assert data["image_data"].meta["filename_or_obj"] == result["filename_or_obj"] + + +if __name__ == "__main__": + unittest.main()