From eeec0beb15866d8175560b1c834ee24723b6ad5d Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Sat, 5 Apr 2025 23:46:27 +0100 Subject: [PATCH 01/12] Preliminary fix for discovered issue Signed-off-by: Eric Kerfoot --- monai/transforms/inverse.py | 58 +++++++- tests/transforms/inverse/test_inverse.py | 178 +++++++++++++++++++++++ 2 files changed, 229 insertions(+), 7 deletions(-) create mode 100644 tests/transforms/inverse/test_inverse.py diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index f94f11eca9..46a3a670c7 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -15,6 +15,7 @@ from collections.abc import Hashable, Mapping from contextlib import contextmanager from typing import Any +import threading import torch @@ -70,11 +71,35 @@ class TraceableTransform(Transform): `MONAI_TRACE_TRANSFORM` when initializing the class. """ - tracing = MONAIEnvVars.trace_transform() != "0" + # tracing = MONAIEnvVars.trace_transform() != "0" - def set_tracing(self, tracing: bool) -> None: - """Set whether to trace transforms.""" - self.tracing = tracing + # def set_tracing(self, tracing: bool) -> None: + # """Set whether to trace transforms.""" + # self.tracing = tracing + + def _init_trace_threadlocal(self): + # needed since this class is meant to be a trait with no constructor + if not hasattr(self, "_tracing"): + self._tracing = threading.local() + + # This is True while the above initialising _tracing is False when this is + # called from a different thread than the one initialising _tracing. + if not hasattr(self._tracing, "value"): + self._tracing.value = MONAIEnvVars.trace_transform() != "0" + + @property + def tracing(self) -> bool: + """ + Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`. + """ + self._init_trace_threadlocal() + return self._tracing.value + + @tracing.setter + def tracing(self, val: bool): + """Sets the thread-local tracing state to `val`.""" + self._init_trace_threadlocal() + self._tracing.value = val @staticmethod def trace_key(key: Hashable = None): @@ -291,7 +316,7 @@ def check_transforms_match(self, transform: Mapping) -> None: def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False): """ - Get most recent transform for the stack. + Get most recent matching transform for the current class from the sequence of applied operations. Args: data: dictionary of data or `MetaTensor`. @@ -316,9 +341,28 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations()) else: raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.") + + # Find the last transform whose name matches that of this class, this allows Invertd to ignore applied + # operations added by transforms it is not trying to invert, ie. those added in postprocessing. + idx=-1 + # for i in reversed(range(len(all_transforms))): + # xform_name = all_transforms[i].get(TraceKeys.CLASS_NAME, "") + # if xform_name == self.__class__.__name__: + # idx=i # if nothing found, idx remains -1 so replicating previous behaviour + # break + + # print(f"get_most_recent_transform {id(data):x} {type(data).__name__} {pop} {id(all_transforms):x} {len(all_transforms)}") + + if not all_transforms: + raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'") + if check: - self.check_transforms_match(all_transforms[-1]) - return all_transforms.pop() if pop else all_transforms[-1] + if not (-len(all_transforms)<=idx 0) + + # post = self.postprocessing(pre) + + # self.assertTupleEqual(post[self.key].shape, (1, *self.orig_size)) + + + + # @parameterized.expand(product(sum(TEST_DEVICES,[]),[True, False])) + # def test_dataset_dataloader(self, device,use_threads): + # batch_size=2 + # dl_type=ThreadDataLoader if use_threads else DataLoader + + # ds = Dataset([{self.key: self.img.to(device)} for _ in range(20)], transform=self.preprocessing) + + # self.assertGreater(len(ds[0][self.key].applied_operations), 0, "Applied operations are missing") + + # dl = dl_type(ds,num_workers=0, batch_size=batch_size) + + # batch=first(dl) + + # self.assertEqual(len(batch[self.key].applied_operations), batch_size) + # self.assertGreater(len(batch[self.key].applied_operations[0]), 0, "Applied operations are missing") + + # # batch[CommonKeys.PRED] = batch[self.key] + # # post_batch=engine_apply_transform(batch=batch,output={},transform=self.postprocessing) + + + + @parameterized.expand(TEST_DEVICES) + def test_workflow(self, device): + test_data = [{self.key: self.img.clone().to(device)} for _ in range(4)] + batch_size=2 + ds = Dataset(test_data, transform=self.preprocessing) + dl = ThreadDataLoader(ds, num_workers=0, batch_size=batch_size) + # dl = DataLoader(ds,num_workers=0, batch_size=batch_size) + + class AssertAppliedOps(torch.nn.Module): + def forward(self,x): + assert len(x.applied_operations)==x.shape[0] + assert all(len(a)>0 for a in x.applied_operations) + return x + + # def _print(x): + # print(type(x), id(x), x.shape, len(x.applied_operations)) + # del x.applied_operations[:] + # return x + + # postprocessing = Compose([ + # Lambdad(self.key, func=_print), + # ]) + + + + evaluator = SupervisedEvaluator( + device=device, + network=AssertAppliedOps(), + postprocessing=self.postprocessing, + val_data_loader=dl + ) + + # def tensor_struct_info(tstruct): + # if isinstance(tstruct, torch.Tensor): + # return f"{id(tstruct):x} {tuple(tstruct.shape)} {tstruct.dtype} {len(getattr(tstruct,"applied_operations",[]))}" + # elif isinstance(tstruct, Sequence): + # return list(map(tensor_struct_info, tstruct)) + # elif isinstance(tstruct, Mapping): + # return {k: tensor_struct_info(v) for k, v in tstruct.items()} + # else: + # return repr(tstruct) + + # @evaluator.on(IterationEvents.MODEL_COMPLETED) + # def _run_postprocessing(engine:SupervisedEvaluator) -> None: + # print("\n===================\n") + # # print("Batch:",dumps(tensor_struct_info(engine.state.batch),indent=2),flush=True) + # print("Output:",dumps(tensor_struct_info(engine.state.output),indent=2),flush=True) + + # for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): + # # print("Post:",dumps(tensor_struct_info(o),indent=2),flush=True) + # engine.state.batch[i], engine.state.output[i] = engine_apply_transform(b, o, self.postprocessing) + + # # evaluator._register_postprocessing(self.postprocessing) + + evaluator.run() + + # self.assertTrue(len(evaluator.state.batch[0][self.key].applied_operations)>0) + + + +if __name__ == "__main__": + unittest.main() From 11be4cb5cb3b810d92d2fc6329fff96a990abe13 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 8 Apr 2025 18:37:42 +0100 Subject: [PATCH 02/12] Updates to tests and moving items around --- monai/transforms/inverse.py | 20 +- tests/transforms/inverse/test_inverse.py | 619 ++++++++++++++---- tests/transforms/inverse/test_inverse_dict.py | 111 ++++ tests/transforms/{ => inverse}/test_invert.py | 0 .../transforms/{ => inverse}/test_invertd.py | 0 tests/transforms/test_inverse.py | 521 --------------- 6 files changed, 598 insertions(+), 673 deletions(-) create mode 100644 tests/transforms/inverse/test_inverse_dict.py rename tests/transforms/{ => inverse}/test_invert.py (100%) rename tests/transforms/{ => inverse}/test_invertd.py (100%) delete mode 100644 tests/transforms/test_inverse.py diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 46a3a670c7..a47a147b23 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -67,16 +67,10 @@ class TraceableTransform(Transform): The information in the stack of applied transforms must be compatible with the default collate, by only storing strings, numbers and arrays. - `tracing` could be enabled by `self.set_tracing` or setting + `tracing` could be enabled by assigning to `self.tracing` or setting `MONAI_TRACE_TRANSFORM` when initializing the class. """ - # tracing = MONAIEnvVars.trace_transform() != "0" - - # def set_tracing(self, tracing: bool) -> None: - # """Set whether to trace transforms.""" - # self.tracing = tracing - def _init_trace_threadlocal(self): # needed since this class is meant to be a trait with no constructor if not hasattr(self, "_tracing"): @@ -345,13 +339,11 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr # Find the last transform whose name matches that of this class, this allows Invertd to ignore applied # operations added by transforms it is not trying to invert, ie. those added in postprocessing. idx=-1 - # for i in reversed(range(len(all_transforms))): - # xform_name = all_transforms[i].get(TraceKeys.CLASS_NAME, "") - # if xform_name == self.__class__.__name__: - # idx=i # if nothing found, idx remains -1 so replicating previous behaviour - # break - - # print(f"get_most_recent_transform {id(data):x} {type(data).__name__} {pop} {id(all_transforms):x} {len(all_transforms)}") + for i in reversed(range(len(all_transforms))): + xform_name = all_transforms[i].get(TraceKeys.CLASS_NAME, "") + if xform_name == self.__class__.__name__: + idx=i # if nothing found, idx remains -1 so replicating previous behaviour + break if not all_transforms: raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'") diff --git a/tests/transforms/inverse/test_inverse.py b/tests/transforms/inverse/test_inverse.py index b882ff0ad1..01d32e4baf 100644 --- a/tests/transforms/inverse/test_inverse.py +++ b/tests/transforms/inverse/test_inverse.py @@ -11,167 +11,510 @@ from __future__ import annotations -from json import dumps -from itertools import product -import time -from typing import Mapping, Sequence +import random +import sys import unittest +from copy import deepcopy +from functools import partial +from typing import TYPE_CHECKING +from unittest.case import skipUnless +import numpy as np import torch from parameterized import parameterized -from monai.data import MetaTensor, create_test_image_2d, Dataset, ThreadDataLoader, DataLoader -from monai.engines.evaluator import SupervisedEvaluator -from monai.engines.utils import IterationEvents, engine_apply_transform -from monai.transforms import Compose, EnsureChannelFirstd, Resized, Transposed, Invertd, Spacingd -from monai.transforms.utility.array import SimulateDelay -from monai.transforms.utility.dictionary import Lambdad -from monai.utils.misc import first -from monai.utils.enums import CommonKeys +from monai.data import CacheDataset, DataLoader, MetaTensor, create_test_image_2d, create_test_image_3d, decollate_batch +from monai.networks.nets import UNet +from monai.transforms import ( + Affined, + BorderPadd, + CenterScaleCropd, + CenterSpatialCropd, + Compose, + CropForegroundd, + DivisiblePadd, + EnsureChannelFirstd, + Flipd, + FromMetaTensord, + InvertibleTransform, + Lambdad, + LoadImaged, + Orientationd, + RandAffined, + RandAxisFlipd, + RandCropByLabelClassesd, + RandCropByPosNegLabeld, + RandFlipd, + RandLambdad, + Randomizable, + RandRotate90d, + RandRotated, + RandSpatialCropd, + RandSpatialCropSamplesd, + RandWeightedCropd, + RandZoomd, + Resized, + ResizeWithPadOrCrop, + ResizeWithPadOrCropd, + Rotate90d, + Rotated, + Spacingd, + SpatialCropd, + SpatialPadd, + ToMetaTensord, + Transposed, + Zoomd, + allow_missing_keys_mode, + convert_applied_interp_mode, + reset_ops_id, +) +from monai.utils import first, get_seed, optional_import, set_determinism +from tests.test_utils import make_nifti_image, make_rand_affine + +if TYPE_CHECKING: + has_nib = True +else: + _, has_nib = optional_import("nibabel") + +KEYS = ["image", "label"] + +TESTS: list[tuple] = [] + +# For pad, start with odd/even images and add odd/even amounts +for name in ("1D even", "1D odd"): + for val in (3, 4): + for t in ( + partial(SpatialPadd, spatial_size=val, method="symmetric"), + partial(SpatialPadd, spatial_size=val, method="end"), + partial(BorderPadd, spatial_border=[val, val + 1]), + partial(DivisiblePadd, k=val), + partial(ResizeWithPadOrCropd, spatial_size=20 + val), + partial(CenterSpatialCropd, roi_size=10 + val), + partial(CenterScaleCropd, roi_scale=0.8), + partial(CropForegroundd, source_key="label"), + partial(SpatialCropd, roi_center=10, roi_size=10 + val), + partial(SpatialCropd, roi_center=11, roi_size=10 + val), + partial(SpatialCropd, roi_start=val, roi_end=17), + partial(SpatialCropd, roi_start=val, roi_end=16), + partial(RandSpatialCropd, roi_size=12 + val), + partial(ResizeWithPadOrCropd, spatial_size=21 - val), + ): + TESTS.append((t.func.__name__ + name, name, 0, True, t(KEYS))) # type: ignore + +# non-sensical tests: crop bigger or pad smaller or -ve values +for t in ( + partial(DivisiblePadd, k=-3), + partial(CenterSpatialCropd, roi_size=-3), + partial(RandSpatialCropd, roi_size=-3), + partial(SpatialPadd, spatial_size=15), + partial(BorderPadd, spatial_border=[15, 16]), + partial(CenterSpatialCropd, roi_size=30), + partial(SpatialCropd, roi_center=10, roi_size=100), + partial(SpatialCropd, roi_start=3, roi_end=100), +): + TESTS.append((t.func.__name__ + "bad 1D even", "1D even", 0, True, t(KEYS))) # type: ignore + +TESTS.append( + ( + "SpatialPadd (x2) 2d", + "2D", + 0, + True, + SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), + SpatialPadd(KEYS, spatial_size=[118, 117]), + ) +) + +TESTS.append(("SpatialPadd 3d", "3D", 0, True, SpatialPadd(KEYS, spatial_size=[112, 113, 116]))) + +TESTS.append(("SpatialCropd 2d", "2D", 0, True, SpatialCropd(KEYS, [49, 51], [90, 89]))) + +TESTS.append( + ( + "SpatialCropd 3d", + "3D", + 0, + True, + SpatialCropd(KEYS, roi_slices=[slice(s, e) for s, e in zip([None, None, -99], [None, -2, None])]), + ) +) + +TESTS.append(("SpatialCropd 2d", "2D", 0, True, SpatialCropd(KEYS, [49, 51], [390, 89]))) + +TESTS.append(("SpatialCropd 3d", "3D", 0, True, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]))) + +TESTS.append(("RandSpatialCropd 2d", "2D", 0, True, RandSpatialCropd(KEYS, [96, 93], None, True, False))) + +TESTS.append(("RandSpatialCropd 3d", "3D", 0, True, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False))) + +TESTS.append(("BorderPadd 2d", "2D", 0, True, BorderPadd(KEYS, [3, 7, 2, 5]))) + +TESTS.append(("BorderPadd 2d", "2D", 0, True, BorderPadd(KEYS, [3, 7]))) + +TESTS.append(("BorderPadd 3d", "3D", 0, True, BorderPadd(KEYS, [4]))) + +TESTS.append(("DivisiblePadd 2d", "2D", 0, True, DivisiblePadd(KEYS, k=4))) + +TESTS.append(("DivisiblePadd 3d", "3D", 0, True, DivisiblePadd(KEYS, k=[4, 8, 11]))) + +TESTS.append(("CenterSpatialCropd 2d", "2D", 0, True, CenterSpatialCropd(KEYS, roi_size=95))) + +TESTS.append(("CenterSpatialCropd 3d", "3D", 0, True, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]))) + +TESTS.append(("CropForegroundd 2d", "2D", 0, True, CropForegroundd(KEYS, source_key="label", margin=2))) + +TESTS.append(("CropForegroundd 3d", "3D", 0, True, CropForegroundd(KEYS, source_key="label", k_divisible=[5, 101, 2]))) + +TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, True, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) + +TESTS.append(("Flipd 3d", "3D", 0, True, Flipd(KEYS, [1, 2]))) +TESTS.append(("Flipd 3d", "3D", 0, True, Flipd(KEYS, [1, 2]))) + +TESTS.append(("RandFlipd 3d", "3D", 0, True, RandFlipd(KEYS, 1, [1, 2]))) + +TESTS.append(("RandAxisFlipd 3d", "3D", 0, True, RandAxisFlipd(KEYS, 1))) +TESTS.append(("RandAxisFlipd 3d", "3D", 0, True, RandAxisFlipd(KEYS, 1))) + +for acc in [True, False]: + TESTS.append(("Orientationd 3d", "3D", 0, True, Orientationd(KEYS, "RAS", as_closest_canonical=acc))) + +TESTS.append(("Rotate90d 2d", "2D", 0, True, Rotate90d(KEYS))) + +TESTS.append(("Rotate90d 3d", "3D", 0, True, Rotate90d(KEYS, k=2, spatial_axes=(1, 2)))) + +TESTS.append(("RandRotate90d 3d", "3D", 0, True, RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)))) + +TESTS.append(("Spacingd 3d", "3D", 3e-2, True, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) + +TESTS.append(("Resized 2d", "2D", 2e-1, True, Resized(KEYS, [50, 47]))) + +TESTS.append(("Resized 3d", "3D", 5e-2, True, Resized(KEYS, [201, 150, 78]))) + +TESTS.append(("Resized longest 2d", "2D", 2e-1, True, Resized(KEYS, 47, "longest", "area"))) + +TESTS.append(("Resized longest 3d", "3D", 5e-2, True, Resized(KEYS, 201, "longest", "trilinear", True))) + +TESTS.append( + ("Lambdad 2d", "2D", 5e-2, False, Lambdad(KEYS, func=lambda x: x + 5, inv_func=lambda x: x - 5, overwrite=True)) +) + +TESTS.append( + ( + "RandLambdad 3d", + "3D", + 5e-2, + False, + RandLambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True, prob=0.5), + ) +) -# from tests.test_utils import TEST_DEVICES -TEST_DEVICES = [[torch.device("cpu")]] +TESTS.append(("Zoomd 1d", "1D odd", 0, True, Zoomd(KEYS, zoom=2, keep_size=False))) +TESTS.append(("Zoomd 2d", "2D", 2e-1, True, Zoomd(KEYS, zoom=0.9))) -class TestInvertDict(unittest.TestCase): - - def setUp(self): - self.orig_size = (60, 60) - img, seg = create_test_image_2d(*self.orig_size, 2, 10, num_seg_classes=2) - self.img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0]}) - # self.seg = MetaTensor(seg, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0]}) - self.key = CommonKeys.IMAGE - self.new_pixdim = 2.0 - self.new_size = (55, 70) - - - def _print(x): - # print("PRE",f"{id(x):x}",type(x), x.shape, len(x.applied_operations),flush=True) - # time.sleep(0.01) - print("PRE",f"{id(x):x}",type(x).__name__, x.shape, len(x.applied_operations),flush=True) - return x - - self.preprocessing = Compose( - [ - EnsureChannelFirstd(self.key), - # Resized(self.key, self.new_size), - Spacingd(self.key, pixdim=[self.new_pixdim] * 2), - # Transposed(self.key, (0, 2, 1)), - Lambdad(self.key, func=_print) - ] - ) - - - self.postprocessing = Compose([ - # Lambdad(self.key, func=_print), - Invertd(CommonKeys.PRED, transform=self.preprocessing, orig_keys=self.key) - ]) - - # @parameterized.expand(TEST_DEVICES) - # def test_dataloader_read(self, device): - # test_data = [{self.key: self.img.clone().to(device)} for _ in range(4)] - # ds = Dataset(test_data, transform=self.preprocessing) - # dl = ThreadDataLoader(ds, num_workers=0, batch_size=2) - # # dl = DataLoader(ds,num_workers=0, batch_size=2) - - # alldata=list(dl) - - # @parameterized.expand(TEST_DEVICES) - # def test_simple_processing(self, device): - # item = {self.key: self.img.to(device)} - # pre = self.preprocessing(item) - - # nw = int(self.orig_size[0] / self.new_pixdim) - # nh = int(self.orig_size[1] / self.new_pixdim) - - # self.assertTupleEqual(pre[self.key].shape, (1, nh, nw)) - # self.assertTrue(len(pre[self.key].applied_operations) > 0) - - # post = self.postprocessing(pre) - - # self.assertTupleEqual(post[self.key].shape, (1, *self.orig_size)) - +TESTS.append(("Zoomd 3d", "3D", 3e-2, True, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False))) +TESTS.append(("RandZoom 3d", "3D", 9e-2, True, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) - # @parameterized.expand(product(sum(TEST_DEVICES,[]),[True, False])) - # def test_dataset_dataloader(self, device,use_threads): - # batch_size=2 - # dl_type=ThreadDataLoader if use_threads else DataLoader +TESTS.append(("RandRotated, prob 0", "2D", 0, True, RandRotated(KEYS, prob=0, dtype=np.float64))) - # ds = Dataset([{self.key: self.img.to(device)} for _ in range(20)], transform=self.preprocessing) +TESTS.append( + ( + "Rotated 2d", + "2D", + 8e-2, + True, + Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False, dtype=np.float64), + ) +) - # self.assertGreater(len(ds[0][self.key].applied_operations), 0, "Applied operations are missing") +TESTS.append( + ( + "Rotated 3d", + "3D", + 1e-1, + True, + Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True, dtype=np.float64), + ) +) - # dl = dl_type(ds,num_workers=0, batch_size=batch_size) - - # batch=first(dl) - - # self.assertEqual(len(batch[self.key].applied_operations), batch_size) - # self.assertGreater(len(batch[self.key].applied_operations[0]), 0, "Applied operations are missing") - - # # batch[CommonKeys.PRED] = batch[self.key] - # # post_batch=engine_apply_transform(batch=batch,output={},transform=self.postprocessing) +TESTS.append( + ( + "RandRotated 3d", + "3D", + 1e-1, + True, + RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1, dtype=np.float64), # type: ignore + ) +) +TESTS.append(("Transposed 2d", "2D", 0, False, Transposed(KEYS, [0, 2, 1]))) # channel=0 + +TESTS.append(("Transposed 3d", "3D", 0, False, Transposed(KEYS, [0, 3, 1, 2]))) # channel=0 + +TESTS.append( + ( + "Affine 3d", + "3D", + 1e-1, + True, + Affined( + KEYS, + spatial_size=[155, 179, 192], + rotate_params=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_params=[0.5, 0.5], + translate_params=[10, 5, -4], + scale_params=[0.8, 1.3], + ), + ) +) + +TESTS.append( + ( + "RandAffine 3d", + "3D", + 1e-1, + True, + RandAffined( + KEYS, + [155, 179, 192], + prob=1, + padding_mode="zeros", + rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], + shear_range=[(0.5, 0.5)], + translate_range=[10, 5, -4], + scale_range=[(0.8, 1.2), (0.9, 1.3)], + ), + ) +) + +TESTS.append(("RandAffine 3d", "3D", 0, True, RandAffined(KEYS, spatial_size=None, prob=0))) + +TESTS.append( + ( + "RandCropByLabelClassesd 2d", + "2D", + 1e-7, + True, + RandCropByLabelClassesd(KEYS, "label", (99, 96), ratios=[1, 2, 3, 4, 5], num_classes=5, num_samples=10), + ) +) + +TESTS.append( + ("RandCropByPosNegLabeld 2d", "2D", 1e-7, True, RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10)) +) + +TESTS.append(("RandSpatialCropSamplesd 2d", "2D", 1e-7, True, RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10))) + +TESTS.append(("RandWeightedCropd 2d", "2D", 1e-7, True, RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10))) + +TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], t[3], Compose(Compose(t[4:]))) for t in TESTS] + +TESTS = TESTS + TESTS_COMPOSE_X2 + +NUM_SAMPLES = 5 +N_SAMPLES_TESTS = [ + [RandCropByLabelClassesd(KEYS, "label", (110, 99), [1, 2, 3, 4, 5], num_classes=5, num_samples=NUM_SAMPLES)], + [RandCropByPosNegLabeld(KEYS, "label", (110, 99), num_samples=NUM_SAMPLES)], + [RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=NUM_SAMPLES, random_size=False)], + [RandWeightedCropd(KEYS, "label", (90, 91), num_samples=NUM_SAMPLES)], +] + + +def no_collation(x): + return x + + +class TestInverse(unittest.TestCase): + """Test inverse methods. + + If tests are failing, the following function might be useful for displaying + `x`, `fx`, `f⁻¹fx` and `x - f⁻¹fx`. + + .. code-block:: python + + def plot_im(orig, fwd_bck, fwd): + import matplotlib.pyplot as plt + diff_orig_fwd_bck = orig - fwd_bck + ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] + titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] + fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) + vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) + vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) + for im, title, ax in zip(ims_to_show, titles, axes): + _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) + im = np.squeeze(np.array(im)) + while im.ndim > 2: + im = im[..., im.shape[-1] // 2] + im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) + ax.set_title(title, fontsize=25) + ax.axis("off") + fig.colorbar(im_show, ax=ax) + plt.show() + + This can then be added to the exception: + + .. code-block:: python + + except AssertionError: + print( + f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) + if orig[0].ndim > 1: + plot_im(orig, fwd_bck, unmodified) + """ + def setUp(self): + if not has_nib: + self.skipTest("nibabel required for test_inverse") - @parameterized.expand(TEST_DEVICES) - def test_workflow(self, device): - test_data = [{self.key: self.img.clone().to(device)} for _ in range(4)] - batch_size=2 - ds = Dataset(test_data, transform=self.preprocessing) - dl = ThreadDataLoader(ds, num_workers=0, batch_size=batch_size) - # dl = DataLoader(ds,num_workers=0, batch_size=batch_size) + set_determinism(seed=0) - class AssertAppliedOps(torch.nn.Module): - def forward(self,x): - assert len(x.applied_operations)==x.shape[0] - assert all(len(a)>0 for a in x.applied_operations) - return x + self.all_data = {} - # def _print(x): - # print(type(x), id(x), x.shape, len(x.applied_operations)) - # del x.applied_operations[:] - # return x + affine = make_rand_affine() + affine[0] *= 2 - # postprocessing = Compose([ - # Lambdad(self.key, func=_print), - # ]) + for size in [10, 11]: + # pad 5 onto both ends so that cropping can be lossless + im_1d = np.pad(np.arange(size), 5)[None] + name = "1D even" if size % 2 == 0 else "1D odd" + self.all_data[name] = { + "image": torch.as_tensor(np.array(im_1d, copy=True)), + "label": torch.as_tensor(np.array(im_1d, copy=True)), + "other": torch.as_tensor(np.array(im_1d, copy=True)), + } - + im_2d_fname, seg_2d_fname = (make_nifti_image(i) for i in create_test_image_2d(101, 100)) + im_3d_fname, seg_3d_fname = (make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)) - evaluator = SupervisedEvaluator( - device=device, - network=AssertAppliedOps(), - postprocessing=self.postprocessing, - val_data_loader=dl + load_ims = Compose( + [LoadImaged(KEYS), EnsureChannelFirstd(KEYS, channel_dim="no_channel"), FromMetaTensord(KEYS)] + ) + self.all_data["2D"] = load_ims({"image": im_2d_fname, "label": seg_2d_fname}) + self.all_data["3D"] = load_ims({"image": im_3d_fname, "label": seg_3d_fname}) + + def tearDown(self): + set_determinism(seed=None) + + def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): + for key in keys: + orig = orig_d[key] + fwd_bck = fwd_bck_d[key] + if isinstance(fwd_bck, torch.Tensor): + fwd_bck = fwd_bck.cpu().numpy() + unmodified = unmodified_d[key] + if isinstance(orig, np.ndarray): + mean_diff = np.mean(np.abs(orig - fwd_bck)) + resized = ResizeWithPadOrCrop(orig.shape[1:])(unmodified) + if isinstance(resized, torch.Tensor): + resized = resized.detach().cpu().numpy() + unmodded_diff = np.mean(np.abs(orig - resized)) + try: + self.assertLessEqual(mean_diff, acceptable_diff) + except AssertionError: + print( + f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" + ) + if orig[0].ndim == 1: + print("orig", orig[0]) + print("fwd_bck", fwd_bck[0]) + print("unmod", unmodified[0]) + raise + + @parameterized.expand(TESTS) + def test_inverse(self, _, data_name, acceptable_diff, is_meta, *transforms): + name = _ + + data = self.all_data[data_name] + if is_meta: + data = ToMetaTensord(KEYS)(data) + + forwards = [data.copy()] + + # Apply forwards + for t in transforms: + if isinstance(t, Randomizable): + t.set_random_state(seed=get_seed()) + forwards.append(t(forwards[-1])) + + # Apply inverses + fwd_bck = forwards[-1].copy() + for i, t in enumerate(reversed(transforms)): + if isinstance(t, InvertibleTransform): + if isinstance(fwd_bck, list): + for j, _fwd_bck in enumerate(fwd_bck): + fwd_bck = t.inverse(_fwd_bck) + self.check_inverse( + name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1][j], acceptable_diff + ) + else: + fwd_bck = t.inverse(fwd_bck) + self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) + + # skip this test if multiprocessing uses 'spawn', as the check is only basic anyway + @skipUnless(torch.multiprocessing.get_start_method() == "spawn", "requires spawn") + def test_fail(self): + t1 = SpatialPadd("image", [10, 5]) + data = t1(self.all_data["2D"]) + + # Check that error is thrown when inverse are used out of order. + t2 = ResizeWithPadOrCropd("image", [10, 5]) + with self.assertRaises(RuntimeError): + t2.inverse(data) + + @parameterized.expand(N_SAMPLES_TESTS) + def test_inverse_inferred_seg(self, extra_transform): + test_data = [] + for _ in range(20): + image, label = create_test_image_2d(100, 101) + test_data.append({"image": image, "label": label.astype(np.float32)}) + + batch_size = 10 + # num workers = 0 for mac + num_workers = 2 if sys.platform == "linux" else 0 + transforms = Compose( + [EnsureChannelFirstd(KEYS, channel_dim="no_channel"), SpatialPadd(KEYS, (150, 153)), extra_transform] ) - # def tensor_struct_info(tstruct): - # if isinstance(tstruct, torch.Tensor): - # return f"{id(tstruct):x} {tuple(tstruct.shape)} {tstruct.dtype} {len(getattr(tstruct,"applied_operations",[]))}" - # elif isinstance(tstruct, Sequence): - # return list(map(tensor_struct_info, tstruct)) - # elif isinstance(tstruct, Mapping): - # return {k: tensor_struct_info(v) for k, v in tstruct.items()} - # else: - # return repr(tstruct) - - # @evaluator.on(IterationEvents.MODEL_COMPLETED) - # def _run_postprocessing(engine:SupervisedEvaluator) -> None: - # print("\n===================\n") - # # print("Batch:",dumps(tensor_struct_info(engine.state.batch),indent=2),flush=True) - # print("Output:",dumps(tensor_struct_info(engine.state.output),indent=2),flush=True) - - # for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): - # # print("Post:",dumps(tensor_struct_info(o),indent=2),flush=True) - # engine.state.batch[i], engine.state.output[i] = engine_apply_transform(b, o, self.postprocessing) - - # # evaluator._register_postprocessing(self.postprocessing) - - evaluator.run() - - # self.assertTrue(len(evaluator.state.batch[0][self.key].applied_operations)>0) - + dataset = CacheDataset(test_data, transform=transforms, progress=False) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1,)).to(device) + + data = first(loader) + self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES) + + labels = data["label"].to(device) + self.assertIsInstance(labels, MetaTensor) + segs = model(labels).detach().cpu() + segs_decollated = decollate_batch(segs) + self.assertIsInstance(segs_decollated[0], MetaTensor) + # inverse of individual segmentation + seg_metatensor = first(segs_decollated) + # test to convert interpolation mode for 1 data of model output batch + convert_applied_interp_mode(seg_metatensor.applied_operations, mode="nearest", align_corners=None) + + # manually invert the last crop samples + xform = seg_metatensor.applied_operations.pop(-1) + shape_before_extra_xform = xform["orig_size"] + resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform) + with resizer.trace_transform(False): + seg_metatensor = resizer(seg_metatensor) + no_ops_id_tensor = reset_ops_id(deepcopy(seg_metatensor)) + + with allow_missing_keys_mode(transforms): + inv_seg = transforms.inverse({"label": seg_metatensor})["label"] + inv_seg_1 = transforms.inverse({"label": no_ops_id_tensor})["label"] + self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) + self.assertEqual(inv_seg_1.shape[1:], test_data[0]["label"].shape) + + # # Inverse of batch + # batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True) + # with allow_missing_keys_mode(transforms): + # inv_batch = batch_inverter(first(loader)) + # self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape) if __name__ == "__main__": diff --git a/tests/transforms/inverse/test_inverse_dict.py b/tests/transforms/inverse/test_inverse_dict.py new file mode 100644 index 0000000000..1d91b8da3d --- /dev/null +++ b/tests/transforms/inverse/test_inverse_dict.py @@ -0,0 +1,111 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from itertools import product +import unittest + +import torch +from parameterized import parameterized + +from monai.data import MetaTensor, create_test_image_2d, Dataset, ThreadDataLoader, DataLoader +from monai.engines.evaluator import SupervisedEvaluator +from monai.transforms import Compose, EnsureChannelFirstd, Invertd, Spacingd +from monai.transforms.utility.dictionary import Lambdad +from monai.utils.enums import CommonKeys + +from tests.test_utils import TEST_DEVICES + + +class TestInvertDict(unittest.TestCase): + + def setUp(self): + self.orig_size = (60, 60) + img, _ = create_test_image_2d(*self.orig_size, 2, 10, num_seg_classes=2) + self.img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0]}) + self.key = CommonKeys.IMAGE + self.pred = CommonKeys.PRED + self.new_pixdim = 2.0 + + self.preprocessing = Compose([EnsureChannelFirstd(self.key), Spacingd(self.key, pixdim=[self.new_pixdim] * 2)]) + + self.postprocessing = Compose( + [ + Lambdad(self.pred, func=lambda x: x), # tests that added postprocess transforms don't confuse Invertd + Invertd(self.pred, transform=self.preprocessing, orig_keys=self.key), + ] + ) + + @parameterized.expand(TEST_DEVICES) + def test_simple_processing(self, device): + """ + Tests postprocessing operations perform correctly, in particular that `Invertd` does inversion correctly. + + This will apply the preprocessing sequence which resizes the result, then the postprocess sequence which + returns it to the original shape using Invertd. This tests that the shape of the output is the same as the + original image. This will also test that Invertd doesn't get confused if transforms in the postprocessing + sequence are tracing and so adding information to `applied_operations`, this is what `Lambdad` is doing in + `self.postprocessing`. + """ + + item = {self.key: self.img.to(device)} + pre = self.preprocessing(item) + + nw = int(self.orig_size[0] / self.new_pixdim) + nh = int(self.orig_size[1] / self.new_pixdim) + + self.assertTupleEqual(pre[self.key].shape, (1, nh, nw), "Pre-processing did not reshape input correctly") + self.assertTrue(len(pre[self.key].applied_operations) > 0, "Pre-processing transforms did not trace correctly") + + pre[self.pred] = pre[self.key] # the inputs are the prediction for this test + + post = self.postprocessing(pre) + + self.assertTupleEqual( + post[self.pred].shape, (1, *self.orig_size), "Result does not have same shape as original input" + ) + + @parameterized.expand(product(sum(TEST_DEVICES, []), [True, False])) + def test_workflow(self, device, use_threads): + """ + This tests the interaction between pre and postprocesing transform sequences being executed in parallel. + + When the `ThreadDataLoader` is used to load batches, this is done in parallel at times with the execution of + the post-process transform sequence. Previously this encountered a race condition at times because the + `TraceableTransform.tracing` variables of transforms was being toggled in different threads, so at times a + pre-process transform wouldn't trace correctly and so confuse `Invertd`. Using a `SupervisedEvaluator` is + the best way to induce this race condition, other methods didn't get the timing right.. + """ + batch_size = 2 + ds_size = 4 + test_data = [{self.key: self.img.clone().to(device)} for _ in range(ds_size)] + ds = Dataset(test_data, transform=self.preprocessing) + dl_type = ThreadDataLoader if use_threads else DataLoader + dl = dl_type(ds, num_workers=0, batch_size=batch_size) + + class AssertAppliedOps(torch.nn.Module): + def forward(self, x): + assert len(x.applied_operations) == x.shape[0] + assert all(len(a) > 0 for a in x.applied_operations) + return x + + evaluator = SupervisedEvaluator( + device=device, network=AssertAppliedOps(), postprocessing=self.postprocessing, val_data_loader=dl + ) + + evaluator.run() + + self.assertTupleEqual(evaluator.state.output[0][self.pred].shape, (1, *self.orig_size)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transforms/test_invert.py b/tests/transforms/inverse/test_invert.py similarity index 100% rename from tests/transforms/test_invert.py rename to tests/transforms/inverse/test_invert.py diff --git a/tests/transforms/test_invertd.py b/tests/transforms/inverse/test_invertd.py similarity index 100% rename from tests/transforms/test_invertd.py rename to tests/transforms/inverse/test_invertd.py diff --git a/tests/transforms/test_inverse.py b/tests/transforms/test_inverse.py deleted file mode 100644 index 01d32e4baf..0000000000 --- a/tests/transforms/test_inverse.py +++ /dev/null @@ -1,521 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import random -import sys -import unittest -from copy import deepcopy -from functools import partial -from typing import TYPE_CHECKING -from unittest.case import skipUnless - -import numpy as np -import torch -from parameterized import parameterized - -from monai.data import CacheDataset, DataLoader, MetaTensor, create_test_image_2d, create_test_image_3d, decollate_batch -from monai.networks.nets import UNet -from monai.transforms import ( - Affined, - BorderPadd, - CenterScaleCropd, - CenterSpatialCropd, - Compose, - CropForegroundd, - DivisiblePadd, - EnsureChannelFirstd, - Flipd, - FromMetaTensord, - InvertibleTransform, - Lambdad, - LoadImaged, - Orientationd, - RandAffined, - RandAxisFlipd, - RandCropByLabelClassesd, - RandCropByPosNegLabeld, - RandFlipd, - RandLambdad, - Randomizable, - RandRotate90d, - RandRotated, - RandSpatialCropd, - RandSpatialCropSamplesd, - RandWeightedCropd, - RandZoomd, - Resized, - ResizeWithPadOrCrop, - ResizeWithPadOrCropd, - Rotate90d, - Rotated, - Spacingd, - SpatialCropd, - SpatialPadd, - ToMetaTensord, - Transposed, - Zoomd, - allow_missing_keys_mode, - convert_applied_interp_mode, - reset_ops_id, -) -from monai.utils import first, get_seed, optional_import, set_determinism -from tests.test_utils import make_nifti_image, make_rand_affine - -if TYPE_CHECKING: - has_nib = True -else: - _, has_nib = optional_import("nibabel") - -KEYS = ["image", "label"] - -TESTS: list[tuple] = [] - -# For pad, start with odd/even images and add odd/even amounts -for name in ("1D even", "1D odd"): - for val in (3, 4): - for t in ( - partial(SpatialPadd, spatial_size=val, method="symmetric"), - partial(SpatialPadd, spatial_size=val, method="end"), - partial(BorderPadd, spatial_border=[val, val + 1]), - partial(DivisiblePadd, k=val), - partial(ResizeWithPadOrCropd, spatial_size=20 + val), - partial(CenterSpatialCropd, roi_size=10 + val), - partial(CenterScaleCropd, roi_scale=0.8), - partial(CropForegroundd, source_key="label"), - partial(SpatialCropd, roi_center=10, roi_size=10 + val), - partial(SpatialCropd, roi_center=11, roi_size=10 + val), - partial(SpatialCropd, roi_start=val, roi_end=17), - partial(SpatialCropd, roi_start=val, roi_end=16), - partial(RandSpatialCropd, roi_size=12 + val), - partial(ResizeWithPadOrCropd, spatial_size=21 - val), - ): - TESTS.append((t.func.__name__ + name, name, 0, True, t(KEYS))) # type: ignore - -# non-sensical tests: crop bigger or pad smaller or -ve values -for t in ( - partial(DivisiblePadd, k=-3), - partial(CenterSpatialCropd, roi_size=-3), - partial(RandSpatialCropd, roi_size=-3), - partial(SpatialPadd, spatial_size=15), - partial(BorderPadd, spatial_border=[15, 16]), - partial(CenterSpatialCropd, roi_size=30), - partial(SpatialCropd, roi_center=10, roi_size=100), - partial(SpatialCropd, roi_start=3, roi_end=100), -): - TESTS.append((t.func.__name__ + "bad 1D even", "1D even", 0, True, t(KEYS))) # type: ignore - -TESTS.append( - ( - "SpatialPadd (x2) 2d", - "2D", - 0, - True, - SpatialPadd(KEYS, spatial_size=[111, 113], method="end"), - SpatialPadd(KEYS, spatial_size=[118, 117]), - ) -) - -TESTS.append(("SpatialPadd 3d", "3D", 0, True, SpatialPadd(KEYS, spatial_size=[112, 113, 116]))) - -TESTS.append(("SpatialCropd 2d", "2D", 0, True, SpatialCropd(KEYS, [49, 51], [90, 89]))) - -TESTS.append( - ( - "SpatialCropd 3d", - "3D", - 0, - True, - SpatialCropd(KEYS, roi_slices=[slice(s, e) for s, e in zip([None, None, -99], [None, -2, None])]), - ) -) - -TESTS.append(("SpatialCropd 2d", "2D", 0, True, SpatialCropd(KEYS, [49, 51], [390, 89]))) - -TESTS.append(("SpatialCropd 3d", "3D", 0, True, SpatialCropd(KEYS, [49, 51, 44], [90, 89, 93]))) - -TESTS.append(("RandSpatialCropd 2d", "2D", 0, True, RandSpatialCropd(KEYS, [96, 93], None, True, False))) - -TESTS.append(("RandSpatialCropd 3d", "3D", 0, True, RandSpatialCropd(KEYS, [96, 93, 92], None, False, False))) - -TESTS.append(("BorderPadd 2d", "2D", 0, True, BorderPadd(KEYS, [3, 7, 2, 5]))) - -TESTS.append(("BorderPadd 2d", "2D", 0, True, BorderPadd(KEYS, [3, 7]))) - -TESTS.append(("BorderPadd 3d", "3D", 0, True, BorderPadd(KEYS, [4]))) - -TESTS.append(("DivisiblePadd 2d", "2D", 0, True, DivisiblePadd(KEYS, k=4))) - -TESTS.append(("DivisiblePadd 3d", "3D", 0, True, DivisiblePadd(KEYS, k=[4, 8, 11]))) - -TESTS.append(("CenterSpatialCropd 2d", "2D", 0, True, CenterSpatialCropd(KEYS, roi_size=95))) - -TESTS.append(("CenterSpatialCropd 3d", "3D", 0, True, CenterSpatialCropd(KEYS, roi_size=[95, 97, 98]))) - -TESTS.append(("CropForegroundd 2d", "2D", 0, True, CropForegroundd(KEYS, source_key="label", margin=2))) - -TESTS.append(("CropForegroundd 3d", "3D", 0, True, CropForegroundd(KEYS, source_key="label", k_divisible=[5, 101, 2]))) - -TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, True, ResizeWithPadOrCropd(KEYS, [201, 150, 105]))) - -TESTS.append(("Flipd 3d", "3D", 0, True, Flipd(KEYS, [1, 2]))) -TESTS.append(("Flipd 3d", "3D", 0, True, Flipd(KEYS, [1, 2]))) - -TESTS.append(("RandFlipd 3d", "3D", 0, True, RandFlipd(KEYS, 1, [1, 2]))) - -TESTS.append(("RandAxisFlipd 3d", "3D", 0, True, RandAxisFlipd(KEYS, 1))) -TESTS.append(("RandAxisFlipd 3d", "3D", 0, True, RandAxisFlipd(KEYS, 1))) - -for acc in [True, False]: - TESTS.append(("Orientationd 3d", "3D", 0, True, Orientationd(KEYS, "RAS", as_closest_canonical=acc))) - -TESTS.append(("Rotate90d 2d", "2D", 0, True, Rotate90d(KEYS))) - -TESTS.append(("Rotate90d 3d", "3D", 0, True, Rotate90d(KEYS, k=2, spatial_axes=(1, 2)))) - -TESTS.append(("RandRotate90d 3d", "3D", 0, True, RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2)))) - -TESTS.append(("Spacingd 3d", "3D", 3e-2, True, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False))) - -TESTS.append(("Resized 2d", "2D", 2e-1, True, Resized(KEYS, [50, 47]))) - -TESTS.append(("Resized 3d", "3D", 5e-2, True, Resized(KEYS, [201, 150, 78]))) - -TESTS.append(("Resized longest 2d", "2D", 2e-1, True, Resized(KEYS, 47, "longest", "area"))) - -TESTS.append(("Resized longest 3d", "3D", 5e-2, True, Resized(KEYS, 201, "longest", "trilinear", True))) - -TESTS.append( - ("Lambdad 2d", "2D", 5e-2, False, Lambdad(KEYS, func=lambda x: x + 5, inv_func=lambda x: x - 5, overwrite=True)) -) - -TESTS.append( - ( - "RandLambdad 3d", - "3D", - 5e-2, - False, - RandLambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True, prob=0.5), - ) -) - -TESTS.append(("Zoomd 1d", "1D odd", 0, True, Zoomd(KEYS, zoom=2, keep_size=False))) - -TESTS.append(("Zoomd 2d", "2D", 2e-1, True, Zoomd(KEYS, zoom=0.9))) - -TESTS.append(("Zoomd 3d", "3D", 3e-2, True, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False))) - -TESTS.append(("RandZoom 3d", "3D", 9e-2, True, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True))) - -TESTS.append(("RandRotated, prob 0", "2D", 0, True, RandRotated(KEYS, prob=0, dtype=np.float64))) - -TESTS.append( - ( - "Rotated 2d", - "2D", - 8e-2, - True, - Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False, dtype=np.float64), - ) -) - -TESTS.append( - ( - "Rotated 3d", - "3D", - 1e-1, - True, - Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True, dtype=np.float64), - ) -) - -TESTS.append( - ( - "RandRotated 3d", - "3D", - 1e-1, - True, - RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1, dtype=np.float64), # type: ignore - ) -) - -TESTS.append(("Transposed 2d", "2D", 0, False, Transposed(KEYS, [0, 2, 1]))) # channel=0 - -TESTS.append(("Transposed 3d", "3D", 0, False, Transposed(KEYS, [0, 3, 1, 2]))) # channel=0 - -TESTS.append( - ( - "Affine 3d", - "3D", - 1e-1, - True, - Affined( - KEYS, - spatial_size=[155, 179, 192], - rotate_params=[np.pi / 6, -np.pi / 5, np.pi / 7], - shear_params=[0.5, 0.5], - translate_params=[10, 5, -4], - scale_params=[0.8, 1.3], - ), - ) -) - -TESTS.append( - ( - "RandAffine 3d", - "3D", - 1e-1, - True, - RandAffined( - KEYS, - [155, 179, 192], - prob=1, - padding_mode="zeros", - rotate_range=[np.pi / 6, -np.pi / 5, np.pi / 7], - shear_range=[(0.5, 0.5)], - translate_range=[10, 5, -4], - scale_range=[(0.8, 1.2), (0.9, 1.3)], - ), - ) -) - -TESTS.append(("RandAffine 3d", "3D", 0, True, RandAffined(KEYS, spatial_size=None, prob=0))) - -TESTS.append( - ( - "RandCropByLabelClassesd 2d", - "2D", - 1e-7, - True, - RandCropByLabelClassesd(KEYS, "label", (99, 96), ratios=[1, 2, 3, 4, 5], num_classes=5, num_samples=10), - ) -) - -TESTS.append( - ("RandCropByPosNegLabeld 2d", "2D", 1e-7, True, RandCropByPosNegLabeld(KEYS, "label", (99, 96), num_samples=10)) -) - -TESTS.append(("RandSpatialCropSamplesd 2d", "2D", 1e-7, True, RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=10))) - -TESTS.append(("RandWeightedCropd 2d", "2D", 1e-7, True, RandWeightedCropd(KEYS, "label", (90, 91), num_samples=10))) - -TESTS_COMPOSE_X2 = [(t[0] + " Compose", t[1], t[2], t[3], Compose(Compose(t[4:]))) for t in TESTS] - -TESTS = TESTS + TESTS_COMPOSE_X2 - -NUM_SAMPLES = 5 -N_SAMPLES_TESTS = [ - [RandCropByLabelClassesd(KEYS, "label", (110, 99), [1, 2, 3, 4, 5], num_classes=5, num_samples=NUM_SAMPLES)], - [RandCropByPosNegLabeld(KEYS, "label", (110, 99), num_samples=NUM_SAMPLES)], - [RandSpatialCropSamplesd(KEYS, (90, 91), num_samples=NUM_SAMPLES, random_size=False)], - [RandWeightedCropd(KEYS, "label", (90, 91), num_samples=NUM_SAMPLES)], -] - - -def no_collation(x): - return x - - -class TestInverse(unittest.TestCase): - """Test inverse methods. - - If tests are failing, the following function might be useful for displaying - `x`, `fx`, `f⁻¹fx` and `x - f⁻¹fx`. - - .. code-block:: python - - def plot_im(orig, fwd_bck, fwd): - import matplotlib.pyplot as plt - diff_orig_fwd_bck = orig - fwd_bck - ims_to_show = [orig, fwd, fwd_bck, diff_orig_fwd_bck] - titles = ["x", "fx", "f⁻¹fx", "x - f⁻¹fx"] - fig, axes = plt.subplots(1, 4, gridspec_kw={"width_ratios": [i.shape[1] for i in ims_to_show]}) - vmin = min(np.array(i).min() for i in [orig, fwd_bck, fwd]) - vmax = max(np.array(i).max() for i in [orig, fwd_bck, fwd]) - for im, title, ax in zip(ims_to_show, titles, axes): - _vmin, _vmax = (vmin, vmax) if id(im) != id(diff_orig_fwd_bck) else (None, None) - im = np.squeeze(np.array(im)) - while im.ndim > 2: - im = im[..., im.shape[-1] // 2] - im_show = ax.imshow(np.squeeze(im), vmin=_vmin, vmax=_vmax) - ax.set_title(title, fontsize=25) - ax.axis("off") - fig.colorbar(im_show, ax=ax) - plt.show() - - This can then be added to the exception: - - .. code-block:: python - - except AssertionError: - print( - f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" - ) - if orig[0].ndim > 1: - plot_im(orig, fwd_bck, unmodified) - """ - - def setUp(self): - if not has_nib: - self.skipTest("nibabel required for test_inverse") - - set_determinism(seed=0) - - self.all_data = {} - - affine = make_rand_affine() - affine[0] *= 2 - - for size in [10, 11]: - # pad 5 onto both ends so that cropping can be lossless - im_1d = np.pad(np.arange(size), 5)[None] - name = "1D even" if size % 2 == 0 else "1D odd" - self.all_data[name] = { - "image": torch.as_tensor(np.array(im_1d, copy=True)), - "label": torch.as_tensor(np.array(im_1d, copy=True)), - "other": torch.as_tensor(np.array(im_1d, copy=True)), - } - - im_2d_fname, seg_2d_fname = (make_nifti_image(i) for i in create_test_image_2d(101, 100)) - im_3d_fname, seg_3d_fname = (make_nifti_image(i, affine) for i in create_test_image_3d(100, 101, 107)) - - load_ims = Compose( - [LoadImaged(KEYS), EnsureChannelFirstd(KEYS, channel_dim="no_channel"), FromMetaTensord(KEYS)] - ) - self.all_data["2D"] = load_ims({"image": im_2d_fname, "label": seg_2d_fname}) - self.all_data["3D"] = load_ims({"image": im_3d_fname, "label": seg_3d_fname}) - - def tearDown(self): - set_determinism(seed=None) - - def check_inverse(self, name, keys, orig_d, fwd_bck_d, unmodified_d, acceptable_diff): - for key in keys: - orig = orig_d[key] - fwd_bck = fwd_bck_d[key] - if isinstance(fwd_bck, torch.Tensor): - fwd_bck = fwd_bck.cpu().numpy() - unmodified = unmodified_d[key] - if isinstance(orig, np.ndarray): - mean_diff = np.mean(np.abs(orig - fwd_bck)) - resized = ResizeWithPadOrCrop(orig.shape[1:])(unmodified) - if isinstance(resized, torch.Tensor): - resized = resized.detach().cpu().numpy() - unmodded_diff = np.mean(np.abs(orig - resized)) - try: - self.assertLessEqual(mean_diff, acceptable_diff) - except AssertionError: - print( - f"Failed: {name}. Mean diff = {mean_diff} (expected <= {acceptable_diff}), unmodified diff: {unmodded_diff}" - ) - if orig[0].ndim == 1: - print("orig", orig[0]) - print("fwd_bck", fwd_bck[0]) - print("unmod", unmodified[0]) - raise - - @parameterized.expand(TESTS) - def test_inverse(self, _, data_name, acceptable_diff, is_meta, *transforms): - name = _ - - data = self.all_data[data_name] - if is_meta: - data = ToMetaTensord(KEYS)(data) - - forwards = [data.copy()] - - # Apply forwards - for t in transforms: - if isinstance(t, Randomizable): - t.set_random_state(seed=get_seed()) - forwards.append(t(forwards[-1])) - - # Apply inverses - fwd_bck = forwards[-1].copy() - for i, t in enumerate(reversed(transforms)): - if isinstance(t, InvertibleTransform): - if isinstance(fwd_bck, list): - for j, _fwd_bck in enumerate(fwd_bck): - fwd_bck = t.inverse(_fwd_bck) - self.check_inverse( - name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1][j], acceptable_diff - ) - else: - fwd_bck = t.inverse(fwd_bck) - self.check_inverse(name, data.keys(), forwards[-i - 2], fwd_bck, forwards[-1], acceptable_diff) - - # skip this test if multiprocessing uses 'spawn', as the check is only basic anyway - @skipUnless(torch.multiprocessing.get_start_method() == "spawn", "requires spawn") - def test_fail(self): - t1 = SpatialPadd("image", [10, 5]) - data = t1(self.all_data["2D"]) - - # Check that error is thrown when inverse are used out of order. - t2 = ResizeWithPadOrCropd("image", [10, 5]) - with self.assertRaises(RuntimeError): - t2.inverse(data) - - @parameterized.expand(N_SAMPLES_TESTS) - def test_inverse_inferred_seg(self, extra_transform): - test_data = [] - for _ in range(20): - image, label = create_test_image_2d(100, 101) - test_data.append({"image": image, "label": label.astype(np.float32)}) - - batch_size = 10 - # num workers = 0 for mac - num_workers = 2 if sys.platform == "linux" else 0 - transforms = Compose( - [EnsureChannelFirstd(KEYS, channel_dim="no_channel"), SpatialPadd(KEYS, (150, 153)), extra_transform] - ) - - dataset = CacheDataset(test_data, transform=transforms, progress=False) - loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) - - device = "cuda" if torch.cuda.is_available() else "cpu" - model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1,)).to(device) - - data = first(loader) - self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES) - - labels = data["label"].to(device) - self.assertIsInstance(labels, MetaTensor) - segs = model(labels).detach().cpu() - segs_decollated = decollate_batch(segs) - self.assertIsInstance(segs_decollated[0], MetaTensor) - # inverse of individual segmentation - seg_metatensor = first(segs_decollated) - # test to convert interpolation mode for 1 data of model output batch - convert_applied_interp_mode(seg_metatensor.applied_operations, mode="nearest", align_corners=None) - - # manually invert the last crop samples - xform = seg_metatensor.applied_operations.pop(-1) - shape_before_extra_xform = xform["orig_size"] - resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform) - with resizer.trace_transform(False): - seg_metatensor = resizer(seg_metatensor) - no_ops_id_tensor = reset_ops_id(deepcopy(seg_metatensor)) - - with allow_missing_keys_mode(transforms): - inv_seg = transforms.inverse({"label": seg_metatensor})["label"] - inv_seg_1 = transforms.inverse({"label": no_ops_id_tensor})["label"] - self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape) - self.assertEqual(inv_seg_1.shape[1:], test_data[0]["label"].shape) - - # # Inverse of batch - # batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True) - # with allow_missing_keys_mode(transforms): - # inv_batch = batch_inverter(first(loader)) - # self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape) - - -if __name__ == "__main__": - unittest.main() From 6b9e4d6aae4670ee242e05572ab1a783e675ad5a Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Tue, 8 Apr 2025 18:40:59 +0100 Subject: [PATCH 03/12] Fixes --- monai/transforms/inverse.py | 12 +++++------ tests/transforms/inverse/test_inverse_dict.py | 21 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index a47a147b23..47665fbfb5 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -11,11 +11,11 @@ from __future__ import annotations +import threading import warnings from collections.abc import Hashable, Mapping from contextlib import contextmanager from typing import Any -import threading import torch @@ -76,7 +76,7 @@ def _init_trace_threadlocal(self): if not hasattr(self, "_tracing"): self._tracing = threading.local() - # This is True while the above initialising _tracing is False when this is + # This is True while the above initialising _tracing is False when this is # called from a different thread than the one initialising _tracing. if not hasattr(self._tracing, "value"): self._tracing.value = MONAIEnvVars.trace_transform() != "0" @@ -87,7 +87,7 @@ def tracing(self) -> bool: Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`. """ self._init_trace_threadlocal() - return self._tracing.value + return bool(self._tracing.value) @tracing.setter def tracing(self, val: bool): @@ -338,18 +338,18 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr # Find the last transform whose name matches that of this class, this allows Invertd to ignore applied # operations added by transforms it is not trying to invert, ie. those added in postprocessing. - idx=-1 + idx = -1 for i in reversed(range(len(all_transforms))): xform_name = all_transforms[i].get(TraceKeys.CLASS_NAME, "") if xform_name == self.__class__.__name__: - idx=i # if nothing found, idx remains -1 so replicating previous behaviour + idx = i # if nothing found, idx remains -1 so replicating previous behaviour break if not all_transforms: raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'") if check: - if not (-len(all_transforms)<=idx Date: Tue, 8 Apr 2025 18:44:32 +0100 Subject: [PATCH 04/12] DCO Remediation Commit for Eric Kerfoot I, Eric Kerfoot , hereby add my Signed-off-by to this commit: 11be4cb5cb3b810d92d2fc6329fff96a990abe13 I, Eric Kerfoot , hereby add my Signed-off-by to this commit: 6b9e4d6aae4670ee242e05572ab1a783e675ad5a Signed-off-by: Eric Kerfoot From 000b9010023c698f8ec483c09121aabeeedbc244 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Apr 2025 19:25:40 +0100 Subject: [PATCH 05/12] Dependency fix Signed-off-by: Eric Kerfoot --- tests/transforms/inverse/test_inverse_dict.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/transforms/inverse/test_inverse_dict.py b/tests/transforms/inverse/test_inverse_dict.py index 91528ac39a..1998801663 100644 --- a/tests/transforms/inverse/test_inverse_dict.py +++ b/tests/transforms/inverse/test_inverse_dict.py @@ -22,7 +22,7 @@ from monai.transforms import Compose, EnsureChannelFirstd, Invertd, Spacingd from monai.transforms.utility.dictionary import Lambdad from monai.utils.enums import CommonKeys -from tests.test_utils import TEST_DEVICES +from tests.test_utils import TEST_DEVICES, SkipIfNoModule class TestInvertDict(unittest.TestCase): @@ -73,6 +73,7 @@ def test_simple_processing(self, device): post[self.pred].shape, (1, *self.orig_size), "Result does not have same shape as original input" ) + @SkipIfNoModule("ignite") @parameterized.expand(product(sum(TEST_DEVICES, []), [True, False])) def test_workflow(self, device, use_threads): """ From 5da9c0b92a89f619f6509e1d8c71a9a9cd33c0b5 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Apr 2025 19:28:21 +0100 Subject: [PATCH 06/12] Removing change to get_most_recent_transform Signed-off-by: Eric Kerfoot --- monai/transforms/inverse.py | 16 ++-------------- tests/transforms/inverse/test_inverse_dict.py | 1 - 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 47665fbfb5..83ec7f1ab2 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -336,25 +336,13 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr else: raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.") - # Find the last transform whose name matches that of this class, this allows Invertd to ignore applied - # operations added by transforms it is not trying to invert, ie. those added in postprocessing. - idx = -1 - for i in reversed(range(len(all_transforms))): - xform_name = all_transforms[i].get(TraceKeys.CLASS_NAME, "") - if xform_name == self.__class__.__name__: - idx = i # if nothing found, idx remains -1 so replicating previous behaviour - break - if not all_transforms: raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'") if check: - if not (-len(all_transforms) <= idx < len(all_transforms)): - raise IndexError(f"Index '{idx}' not valid for list of applied operations '{all_transforms}'") - - self.check_transforms_match(all_transforms[idx]) + self.check_transforms_match(all_transforms[-1]) - return all_transforms.pop(idx) if pop else all_transforms[idx] + return all_transforms.pop(-1) if pop else all_transforms[-1] def pop_transform(self, data, key: Hashable = None, check: bool = True): """ diff --git a/tests/transforms/inverse/test_inverse_dict.py b/tests/transforms/inverse/test_inverse_dict.py index 1998801663..6dfe8c8816 100644 --- a/tests/transforms/inverse/test_inverse_dict.py +++ b/tests/transforms/inverse/test_inverse_dict.py @@ -39,7 +39,6 @@ def setUp(self): self.postprocessing = Compose( [ - Lambdad(self.pred, func=lambda x: x), # tests that added postprocess transforms don't confuse Invertd Invertd(self.pred, transform=self.preprocessing, orig_keys=self.key), ] ) From 8700f3a3e54ccbd7705ca540612b5da690262334 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 18:29:43 +0000 Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/transforms/inverse/test_inverse_dict.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/transforms/inverse/test_inverse_dict.py b/tests/transforms/inverse/test_inverse_dict.py index 6dfe8c8816..241e6561d7 100644 --- a/tests/transforms/inverse/test_inverse_dict.py +++ b/tests/transforms/inverse/test_inverse_dict.py @@ -20,7 +20,6 @@ from monai.data import DataLoader, Dataset, MetaTensor, ThreadDataLoader, create_test_image_2d from monai.engines.evaluator import SupervisedEvaluator from monai.transforms import Compose, EnsureChannelFirstd, Invertd, Spacingd -from monai.transforms.utility.dictionary import Lambdad from monai.utils.enums import CommonKeys from tests.test_utils import TEST_DEVICES, SkipIfNoModule From cd130adb37ed2dd76cb9b0281566cf6522a5bd1c Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Apr 2025 19:39:05 +0100 Subject: [PATCH 08/12] Fix again Signed-off-by: Eric Kerfoot --- tests/transforms/inverse/test_inverse_dict.py | 2 +- tests/transforms/inverse/test_traceable_transform.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/transforms/inverse/test_inverse_dict.py b/tests/transforms/inverse/test_inverse_dict.py index 6dfe8c8816..821b76f759 100644 --- a/tests/transforms/inverse/test_inverse_dict.py +++ b/tests/transforms/inverse/test_inverse_dict.py @@ -72,8 +72,8 @@ def test_simple_processing(self, device): post[self.pred].shape, (1, *self.orig_size), "Result does not have same shape as original input" ) - @SkipIfNoModule("ignite") @parameterized.expand(product(sum(TEST_DEVICES, []), [True, False])) + @SkipIfNoModule("ignite") def test_workflow(self, device, use_threads): """ This tests the interaction between pre and postprocesing transform sequences being executed in parallel. diff --git a/tests/transforms/inverse/test_traceable_transform.py b/tests/transforms/inverse/test_traceable_transform.py index 6a499b2dd9..8ee7c9e62f 100644 --- a/tests/transforms/inverse/test_traceable_transform.py +++ b/tests/transforms/inverse/test_traceable_transform.py @@ -45,13 +45,13 @@ def test_default(self): self.assertEqual(len(data[expected_key]), 2) self.assertEqual(data[expected_key][-1]["class"], "_TraceTest") - with self.assertRaises(IndexError): + with self.assertRaises(ValueError): a.pop({"test": "test"}) # no stack in the data data = a.pop(data) data = a.pop(data) self.assertEqual(data[expected_key], []) - with self.assertRaises(IndexError): # no more items + with self.assertRaises(ValueError): # no more items a.pop(data) From 668ad99f49a7498af6be499f81af4c429ad48c8e Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Wed, 9 Apr 2025 19:47:24 +0100 Subject: [PATCH 09/12] Formatting Signed-off-by: Eric Kerfoot --- tests/transforms/inverse/test_inverse_dict.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/transforms/inverse/test_inverse_dict.py b/tests/transforms/inverse/test_inverse_dict.py index 10f5177862..466be7411c 100644 --- a/tests/transforms/inverse/test_inverse_dict.py +++ b/tests/transforms/inverse/test_inverse_dict.py @@ -36,11 +36,7 @@ def setUp(self): self.preprocessing = Compose([EnsureChannelFirstd(self.key), Spacingd(self.key, pixdim=[self.new_pixdim] * 2)]) - self.postprocessing = Compose( - [ - Invertd(self.pred, transform=self.preprocessing, orig_keys=self.key), - ] - ) + self.postprocessing = Compose([Invertd(self.pred, transform=self.preprocessing, orig_keys=self.key)]) @parameterized.expand(TEST_DEVICES) def test_simple_processing(self, device): From 3479f04e8cddd58fd8c07b0e41cd391c6385a1cd Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Thu, 10 Apr 2025 11:01:58 +0100 Subject: [PATCH 10/12] Improve the picklability of TraceableTransform Signed-off-by: Eric Kerfoot --- monai/transforms/inverse.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 83ec7f1ab2..cd95ccfa73 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -72,6 +72,7 @@ class TraceableTransform(Transform): """ def _init_trace_threadlocal(self): + """Create a `_tracing` instance member to store the thread-local tracing state value.""" # needed since this class is meant to be a trait with no constructor if not hasattr(self, "_tracing"): self._tracing = threading.local() @@ -81,6 +82,12 @@ def _init_trace_threadlocal(self): if not hasattr(self._tracing, "value"): self._tracing.value = MONAIEnvVars.trace_transform() != "0" + def __getstate__(self): + """When pickling, delete the `_tracing` member first, if present, since it's not picklable.""" + if hasattr(self, "_tracing"): + del self._tracing # this can always be re-created with the default value + return super().__getstate__() + @property def tracing(self) -> bool: """ From 60d946b47e8a95c421e51ec4c76dd9b56776d069 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Thu, 10 Apr 2025 11:46:44 +0100 Subject: [PATCH 11/12] __getstate__ update for compatibility Signed-off-by: Eric Kerfoot --- monai/transforms/inverse.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index cd95ccfa73..daece5df0c 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -83,10 +83,11 @@ def _init_trace_threadlocal(self): self._tracing.value = MONAIEnvVars.trace_transform() != "0" def __getstate__(self): - """When pickling, delete the `_tracing` member first, if present, since it's not picklable.""" - if hasattr(self, "_tracing"): - del self._tracing # this can always be re-created with the default value - return super().__getstate__() + """When pickling, remove the `_tracing` member from the output, if present, since it's not picklable.""" + _dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object + _slots = getattr(self, "__slots__", None) + _dict.pop("_tracing", None) # remove tracing + return _dict if _slots is None else (_dict, _slots) @property def tracing(self) -> bool: From 4be65d2f0a6208087b171041b76adeadf97bd276 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot Date: Thu, 10 Apr 2025 12:14:54 +0100 Subject: [PATCH 12/12] Fix Signed-off-by: Eric Kerfoot --- monai/transforms/inverse.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index daece5df0c..2f57f4614a 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -85,9 +85,9 @@ def _init_trace_threadlocal(self): def __getstate__(self): """When pickling, remove the `_tracing` member from the output, if present, since it's not picklable.""" _dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object - _slots = getattr(self, "__slots__", None) + _slots = {k: getattr(self, k) for k in getattr(self, "__slots__", [])} _dict.pop("_tracing", None) # remove tracing - return _dict if _slots is None else (_dict, _slots) + return _dict if len(_slots) == 0 else (_dict, _slots) @property def tracing(self) -> bool: