diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 4a068666d4..7bc2cb38fa 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -124,7 +124,7 @@ jobs: strategy: fail-fast: false matrix: - pytorch-version: ['2.4.1', '2.5.1', '2.6.0'] # FIXME: add 'latest' back once PyTorch 2.7 issues are resolved + pytorch-version: ['2.4.1', '2.5.1', '2.6.0', 'latest'] timeout-minutes: 40 steps: - uses: actions/checkout@v4 diff --git a/docs/requirements.txt b/docs/requirements.txt index 924a9af1f9..c5cf0b5baa 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ -f https://download.pytorch.org/whl/cpu/torch-2.3.0%2Bcpu-cp39-cp39-linux_x86_64.whl -torch>=2.4.1, <2.7.0 +torch>=2.4.1 pytorch-ignite==0.4.11 numpy>=1.20 itk>=5.2 diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 2d39dfdbc1..7b4333be36 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -11,6 +11,7 @@ from __future__ import annotations +import sys from collections.abc import Sequence import torch @@ -526,6 +527,11 @@ def forward( ValueError: When affine and image batch dimension differ. """ + + # In some cases it's necessary to convert inputs to grid_sample from float64 to float32 to work around known + # issues with PyTorch, see https://github.com/Project-MONAI/MONAI/pull/8429 + convert_f32 = sys.platform == "win32" and src.dtype == torch.float64 and src.device == torch.device("cpu") + # validate `theta` if not isinstance(theta, torch.Tensor): raise TypeError(f"theta must be torch.Tensor but is {type(theta).__name__}.") @@ -582,11 +588,17 @@ def forward( ) grid = nn.functional.affine_grid(theta=theta[:, :sr], size=list(dst_size), align_corners=self.align_corners) + + _input = src.contiguous() + if convert_f32: + _input = _input.to(torch.float32) + grid = grid.to(torch.float32) + dst = nn.functional.grid_sample( - input=src.contiguous(), - grid=grid, - mode=self.mode, - padding_mode=self.padding_mode, - align_corners=self.align_corners, + input=_input, grid=grid, mode=self.mode, padding_mode=self.padding_mode, align_corners=self.align_corners ) + + if convert_f32: + dst = dst.to(torch.float64) + return dst diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a75bb390cd..2bb81863b2 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -14,6 +14,7 @@ from __future__ import annotations +import sys import warnings from collections.abc import Callable, Sequence from copy import deepcopy @@ -2106,13 +2107,30 @@ def __call__( if self.norm_coords: for i, dim in enumerate(img_t.shape[sr + 1 : 0 : -1]): grid_t[0, ..., i] *= 2.0 / max(2, dim) + + # In some cases it's necessary to convert inputs to grid_sample from float64 to float32 to work around known + # issues with PyTorch, see https://github.com/Project-MONAI/MONAI/pull/8429 + convert_f32 = ( + sys.platform == "win32" and img_t.dtype == torch.float64 and img_t.device == torch.device("cpu") + ) + + _img_t = img_t.unsqueeze(0) + + if convert_f32: + _img_t = _img_t.to(torch.float32) + grid_t = grid_t.to(torch.float32) + out = torch.nn.functional.grid_sample( - img_t.unsqueeze(0), + _img_t, grid_t, mode=_interp_mode, padding_mode=_padding_mode, align_corners=None if _align_corners == TraceKeys.NONE else _align_corners, # type: ignore )[0] + + if convert_f32: + out = out.to(torch.float64) + out_val, *_ = convert_to_dst_type(out, dst=img, dtype=np.float32) return out_val diff --git a/pyproject.toml b/pyproject.toml index 863b0fd2be..76b26731bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ requires = [ "wheel", "setuptools", - "torch>=2.4.1, <2.7.0", + "torch>=2.4.1", "ninja", "packaging" ] diff --git a/requirements-dev.txt b/requirements-dev.txt index 75aefaca99..87840556ee 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -20,7 +20,7 @@ pyflakes black>=25.1.0 isort>=5.1, <6.0 ruff -pytype>=2020.6.1; platform_system != "Windows" +pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows" types-setuptools mypy>=1.5.0, <1.12.0 ninja diff --git a/requirements.txt b/requirements.txt index 903ce0ce00..f0d1f54083 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch>=2.4.1, <2.7.0 +torch>=2.4.1 numpy>=1.24,<3.0 diff --git a/setup.cfg b/setup.cfg index 0067ab5f40..fc415e6cc0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ setup_requires = ninja packaging install_requires = - torch>=2.4.1, <2.7.0 + torch>=2.4.1 numpy>=1.24,<3.0 [options.extras_require] diff --git a/tests/integration/test_pad_collation.py b/tests/integration/test_pad_collation.py index 9d5012c9a3..a236521cd9 100644 --- a/tests/integration/test_pad_collation.py +++ b/tests/integration/test_pad_collation.py @@ -11,8 +11,10 @@ from __future__ import annotations +import os import random import unittest +from contextlib import redirect_stderr from functools import wraps import numpy as np @@ -35,7 +37,7 @@ RandZoomd, ToTensor, ) -from monai.utils import set_determinism +from monai.utils import first, set_determinism @wraps(pad_list_data_collate) @@ -97,8 +99,9 @@ def test_pad_collation(self, t_type, collate_method, transform): # Default collation should raise an error loader_fail = DataLoader(dataset, batch_size=10) with self.assertRaises(RuntimeError): - for _ in loader_fail: - pass + # stifle PyTorch error reporting, we expect failure so don't need to look at it + with open(os.devnull) as f, redirect_stderr(f): + _ = first(loader_fail) # Padded collation shouldn't loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method) diff --git a/tests/lazy_transforms_utils.py b/tests/lazy_transforms_utils.py index 3a737df201..29e755aaab 100644 --- a/tests/lazy_transforms_utils.py +++ b/tests/lazy_transforms_utils.py @@ -11,6 +11,7 @@ from __future__ import annotations +import sys from copy import deepcopy from monai.data import MetaTensor, set_track_meta @@ -62,6 +63,13 @@ def test_resampler_lazy( resampler.set_random_state(seed=seed) set_track_meta(True) resampler.lazy = True + + # FIXME: this is a fix for https://github.com/Project-MONAI/MONAI/pull/8429, remove when PyTorch has + # fixed the underlying issue + if sys.platform == "win32": + atol = 1e-4 + rtol = 1e-4 + pending_output = resampler(**deepcopy(call_param)) if output_idx is not None: expected_output, pending_output = (expected_output[output_idx], pending_output[output_idx])