Skip to content

Commit 7bba561

Browse files
authored
Allow strings for interpolation param in resize transforms (#9461)
1 parent 4d4e406 commit 7bba561

8 files changed

Lines changed: 211 additions & 135 deletions

File tree

docs/source/transforms.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,9 @@ Functionals
342342
v2.functional.perspective
343343
v2.functional.elastic
344344

345+
.. autoclass:: torchvision.transforms.v2.InterpolationMode
346+
:members:
347+
345348
Color
346349
^^^^^
347350

test/test_transforms_v2.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,7 @@ def adapt_fill(value, *, dtype):
495495
transforms.InterpolationMode.BICUBIC,
496496
transforms.InterpolationMode.LANCZOS,
497497
]
498+
INTERPOLATION_MODES_STR = ["nearest", "nearest-exact", "bilinear", "bicubic", "lanczos"]
498499

499500

500501
def reference_affine_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
@@ -885,7 +886,7 @@ def _check_output_size(self, input, output, *, size, max_size):
885886
@pytest.mark.parametrize("size", OUTPUT_SIZES)
886887
# `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
887888
# The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
888-
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
889+
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES_STR) - {"nearest"})
889890
@pytest.mark.parametrize("use_max_size", [True, False])
890891
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
891892
def test_image_correctness(self, size, interpolation, use_max_size, fn):
@@ -898,7 +899,7 @@ def test_image_correctness(self, size, interpolation, use_max_size, fn):
898899
expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg))
899900

900901
self._check_output_size(image, actual, size=size, **max_size_kwarg)
901-
atol = 2 if interpolation is transforms.InterpolationMode.LANCZOS else 1
902+
atol = 2 if interpolation == "lanczos" else 1
902903
torch.testing.assert_close(actual, expected, atol=atol, rtol=0)
903904

904905
def _reference_resize_bounding_boxes(self, bounding_boxes, format, *, size, max_size=None):
@@ -1096,6 +1097,26 @@ def test_interpolation_int(self, interpolation, make_input):
10961097

10971098
assert_equal(actual, expected)
10981099

1100+
@pytest.mark.parametrize(
1101+
"interpolation_str, interpolation_enum",
1102+
[
1103+
("nearest", transforms.InterpolationMode.NEAREST),
1104+
("nearest-exact", transforms.InterpolationMode.NEAREST_EXACT),
1105+
("bilinear", transforms.InterpolationMode.BILINEAR),
1106+
("bicubic", transforms.InterpolationMode.BICUBIC),
1107+
("lanczos", transforms.InterpolationMode.LANCZOS),
1108+
],
1109+
)
1110+
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
1111+
@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
1112+
def test_interpolation_str(self, interpolation_str, interpolation_enum, fn, make_input):
1113+
input = make_input(self.INPUT_SIZE)
1114+
1115+
expected = fn(input, size=self.OUTPUT_SIZES[0], interpolation=interpolation_enum, antialias=True)
1116+
actual = fn(input, size=self.OUTPUT_SIZES[0], interpolation=interpolation_str, antialias=True)
1117+
1118+
assert_equal(actual, expected)
1119+
10991120
def test_transform_unknown_size_error(self):
11001121
with pytest.raises(ValueError, match="size can be an integer, a sequence of one or two integers, or None"):
11011122
transforms.Resize(size=object())
@@ -1554,9 +1575,7 @@ def test_transform(self, make_input, device):
15541575
@pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"])
15551576
@pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"])
15561577
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
1557-
@pytest.mark.parametrize(
1558-
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
1559-
)
1578+
@pytest.mark.parametrize("interpolation", ["nearest", "bilinear"])
15601579
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
15611580
def test_functional_image_correctness(self, angle, translate, scale, shear, center, interpolation, fill):
15621581
image = make_image(dtype=torch.uint8, device="cpu")
@@ -1587,12 +1606,10 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
15871606
)
15881607

15891608
mae = (actual.float() - expected.float()).abs().mean()
1590-
assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
1609+
# Parenthesize the threshold: unparenthesized, this parses as `(mae < 2) if ... else 8`,
# so the non-"nearest" case would assert the truthy constant 8 and never check `mae`.
assert mae < (2 if interpolation == "nearest" else 8)
15911610

15921611
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
1593-
@pytest.mark.parametrize(
1594-
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
1595-
)
1612+
@pytest.mark.parametrize("interpolation", ["nearest", "bilinear"])
15961613
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
15971614
@pytest.mark.parametrize("seed", list(range(5)))
15981615
def test_transform_image_correctness(self, center, interpolation, fill, seed):
@@ -1611,7 +1628,7 @@ def test_transform_image_correctness(self, center, interpolation, fill, seed):
16111628
expected = F.to_image(transform(F.to_pil_image(image)))
16121629

16131630
mae = (actual.float() - expected.float()).abs().mean()
1614-
assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8
1631+
# Parenthesize the threshold: unparenthesized, this parses as `(mae < 2) if ... else 8`,
# so the non-"nearest" case would assert the truthy constant 8 and never check `mae`.
assert mae < (2 if interpolation == "nearest" else 8)
16151632

16161633
def _compute_affine_matrix(self, *, angle, translate, scale, shear, center):
16171634
rot = math.radians(angle)
@@ -2142,9 +2159,7 @@ def test_transform(self, make_input, device):
21422159

21432160
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
21442161
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
2145-
@pytest.mark.parametrize(
2146-
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
2147-
)
2162+
@pytest.mark.parametrize("interpolation", ["nearest", "bilinear"])
21482163
@pytest.mark.parametrize("expand", [False, True])
21492164
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
21502165
def test_functional_image_correctness(self, angle, center, interpolation, expand, fill):
@@ -2160,12 +2175,10 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
21602175
)
21612176

21622177
mae = (actual.float() - expected.float()).abs().mean()
2163-
assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
2178+
# Parenthesize the threshold: unparenthesized, this parses as `(mae < 1) if ... else 6`,
# so the non-"nearest" case would assert the truthy constant 6 and never check `mae`.
assert mae < (1 if interpolation == "nearest" else 6)
21642179

21652180
@pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"])
2166-
@pytest.mark.parametrize(
2167-
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
2168-
)
2181+
@pytest.mark.parametrize("interpolation", ["nearest", "bilinear"])
21692182
@pytest.mark.parametrize("expand", [False, True])
21702183
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
21712184
@pytest.mark.parametrize("seed", list(range(5)))
@@ -2189,7 +2202,7 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill,
21892202
expected = F.to_image(transform(F.to_pil_image(image)))
21902203

21912204
mae = (actual.float() - expected.float()).abs().mean()
2192-
assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6
2205+
# Parenthesize the threshold: unparenthesized, this parses as `(mae < 1) if ... else 6`,
# so the non-"nearest" case would assert the truthy constant 6 and never check `mae`.
assert mae < (1 if interpolation == "nearest" else 6)
21932206

21942207
def _compute_output_canvas_size(self, *, expand, canvas_size, affine_matrix):
21952208
if not expand:
@@ -4150,6 +4163,9 @@ class TestAutoAugmentTransforms:
41504163
# rotate, are tested in their respective classes. The rest of the tests here are mostly smoke tests.
41514164

41524165
def _reference_shear_translate(self, image, *, transform_id, magnitude, interpolation, fill):
4166+
if isinstance(interpolation, str):
4167+
interpolation = transforms.InterpolationMode(interpolation)
4168+
41534169
if isinstance(image, PIL.Image.Image):
41544170
input = image
41554171
else:
@@ -4173,9 +4189,7 @@ def _reference_shear_translate(self, image, *, transform_id, magnitude, interpol
41734189

41744190
@pytest.mark.parametrize("transform_id", ["ShearX", "ShearY", "TranslateX", "TranslateY"])
41754191
@pytest.mark.parametrize("magnitude", [0.3, -0.2, 0.0])
4176-
@pytest.mark.parametrize(
4177-
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
4178-
)
4192+
@pytest.mark.parametrize("interpolation", ["nearest", "bilinear"])
41794193
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
41804194
@pytest.mark.parametrize("input_type", ["Tensor", "PIL"])
41814195
def test_correctness_shear_translate(self, transform_id, magnitude, interpolation, fill, input_type):
@@ -4208,7 +4222,7 @@ def test_correctness_shear_translate(self, transform_id, magnitude, interpolatio
42084222

42094223
if "Shear" in transform_id and input_type == "Tensor":
42104224
mae = (actual.float() - expected.float()).abs().mean()
4211-
assert mae < (12 if interpolation is transforms.InterpolationMode.NEAREST else 5)
4225+
assert mae < (12 if interpolation == "nearest" else 5)
42124226
else:
42134227
assert_close(actual, expected, rtol=0, atol=1)
42144228

@@ -4537,7 +4551,7 @@ def test_transform(self, param, value, make_input):
45374551

45384552
# `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
45394553
# The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
4540-
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES) - {transforms.InterpolationMode.NEAREST})
4554+
@pytest.mark.parametrize("interpolation", set(INTERPOLATION_MODES_STR) - {"nearest"})
45414555
def test_functional_image_correctness(self, interpolation):
45424556
image = make_image(self.INPUT_SIZE, dtype=torch.uint8)
45434557

@@ -4550,9 +4564,7 @@ def test_functional_image_correctness(self, interpolation):
45504564
)
45514565
)
45524566

4553-
torch.testing.assert_close(
4554-
actual, expected, atol=2 if interpolation is transforms.InterpolationMode.LANCZOS else 1, rtol=0
4555-
)
4567+
torch.testing.assert_close(actual, expected, atol=2 if interpolation == "lanczos" else 1, rtol=0)
45564568

45574569
def _reference_resized_crop_bounding_boxes(self, bounding_boxes, *, top, left, height, width, size):
45584570
new_height, new_width = size
@@ -5257,9 +5269,7 @@ def test_transform_error(self, distortion_scale):
52575269
transforms.RandomPerspective(distortion_scale=distortion_scale)
52585270

52595271
@pytest.mark.parametrize("coefficients", COEFFICIENTS)
5260-
@pytest.mark.parametrize(
5261-
"interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]
5262-
)
5272+
@pytest.mark.parametrize("interpolation", ["nearest", "bilinear"])
52635273
@pytest.mark.parametrize("fill", CORRECTNESS_FILLS)
52645274
def test_image_functional_correctness(self, coefficients, interpolation, fill):
52655275
image = make_image(dtype=torch.uint8, device="cpu")
@@ -5278,7 +5288,7 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill):
52785288
)
52795289
)
52805290

5281-
if interpolation is transforms.InterpolationMode.BILINEAR:
5291+
if interpolation == "bilinear":
52825292
abs_diff = (actual.float() - expected.float()).abs()
52835293
assert (abs_diff > 1).float().mean() < 7e-2
52845294
mae = abs_diff.mean()

torchvision/transforms/v2/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from torchvision.transforms import AutoAugmentPolicy, InterpolationMode # usort: skip
1+
from torchvision.transforms import AutoAugmentPolicy # usort: skip
2+
from torchvision.transforms.functional import InterpolationMode # usort: skip
23

34
from . import functional # usort: skip
45

torchvision/transforms/v2/_auto_augment.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torchvision import transforms as _transforms, tv_tensors
99
from torchvision.transforms import _functional_tensor as _FT
1010
from torchvision.transforms.v2 import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
11-
from torchvision.transforms.v2.functional._geometry import _check_interpolation
1211
from torchvision.transforms.v2.functional._meta import get_size
1312
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT
1413

@@ -22,11 +21,11 @@ class _AutoAugmentBase(Transform):
2221
def __init__(
2322
self,
2423
*,
25-
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
24+
interpolation: Union[str, InterpolationMode, int] = "nearest",
2625
fill: Union[_FillType, dict[Union[type, str], _FillType]] = None,
2726
) -> None:
2827
super().__init__()
29-
self.interpolation = _check_interpolation(interpolation)
28+
self.interpolation = interpolation
3029
self.fill = fill
3130
self._fill = _setup_fill_arg(fill)
3231

@@ -91,7 +90,7 @@ def _apply_image_or_video_transform(
9190
image: ImageOrVideo,
9291
transform_id: str,
9392
magnitude: float,
94-
interpolation: Union[InterpolationMode, int],
93+
interpolation: Union[str, InterpolationMode, int],
9594
fill: dict[Union[type, str], _FillTypeJIT],
9695
) -> ImageOrVideo:
9796
# Note: this cast is wrong and is only here to make mypy happy (it disagrees with torchscript)
@@ -188,9 +187,13 @@ class AutoAugment(_AutoAugmentBase):
188187
Args:
189188
policy (AutoAugmentPolicy, optional): Desired policy enum defined by
190189
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
191-
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
192-
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
193-
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
190+
interpolation (str or InterpolationMode, optional): Desired interpolation enum defined by
191+
:class:`torchvision.transforms.v2.InterpolationMode`.
192+
Accepted string values are ``"nearest"``, ``"nearest-exact"``, ``"bilinear"``, ``"bicubic"``,
193+
``"box"``, ``"hamming"``, and ``"lanczos"``.
194+
``"box"``, ``"hamming"``, and ``"lanczos"`` are only supported for PIL images.
195+
The corresponding ``InterpolationMode`` enum values and Pillow integer
196+
constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well. Default is ``"nearest"``.
194197
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
195198
image. If given a number, the value is used for all bands respectively.
196199
"""
@@ -226,7 +229,7 @@ class AutoAugment(_AutoAugmentBase):
226229
def __init__(
227230
self,
228231
policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
229-
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
232+
interpolation: Union[str, InterpolationMode, int] = "nearest",
230233
fill: Union[_FillType, dict[Union[type, str], _FillType]] = None,
231234
) -> None:
232235
super().__init__(interpolation=interpolation, fill=fill)
@@ -366,9 +369,13 @@ class RandAugment(_AutoAugmentBase):
366369
must be non-negative integer. Default: 2.
367370
magnitude (int, optional): Magnitude for all the transformations.
368371
num_magnitude_bins (int, optional): The number of different magnitude values.
369-
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
370-
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
371-
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
372+
interpolation (str or InterpolationMode, optional): Desired interpolation enum defined by
373+
:class:`torchvision.transforms.v2.InterpolationMode`.
374+
Accepted string values are ``"nearest"``, ``"nearest-exact"``, ``"bilinear"``, ``"bicubic"``,
375+
``"box"``, ``"hamming"``, and ``"lanczos"``.
376+
``"box"``, ``"hamming"``, and ``"lanczos"`` are only supported for PIL images.
377+
The corresponding ``InterpolationMode`` enum values and Pillow integer
378+
constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well. Default is ``"nearest"``.
372379
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
373380
image. If given a number, the value is used for all bands respectively.
374381
"""
@@ -405,7 +412,7 @@ def __init__(
405412
num_ops: int = 2,
406413
magnitude: int = 9,
407414
num_magnitude_bins: int = 31,
408-
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
415+
interpolation: Union[str, InterpolationMode, int] = "nearest",
409416
fill: Union[_FillType, dict[Union[type, str], _FillType]] = None,
410417
) -> None:
411418
super().__init__(interpolation=interpolation, fill=fill)
@@ -447,9 +454,13 @@ class TrivialAugmentWide(_AutoAugmentBase):
447454
448455
Args:
449456
num_magnitude_bins (int, optional): The number of different magnitude values.
450-
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
451-
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
452-
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
457+
interpolation (str or InterpolationMode, optional): Desired interpolation enum defined by
458+
:class:`torchvision.transforms.v2.InterpolationMode`.
459+
Accepted string values are ``"nearest"``, ``"nearest-exact"``, ``"bilinear"``, ``"bicubic"``,
460+
``"box"``, ``"hamming"``, and ``"lanczos"``.
461+
``"box"``, ``"hamming"``, and ``"lanczos"`` are only supported for PIL images.
462+
The corresponding ``InterpolationMode`` enum values and Pillow integer
463+
constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well. Default is ``"nearest"``.
453464
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
454465
image. If given a number, the value is used for all bands respectively.
455466
"""
@@ -478,7 +489,7 @@ class TrivialAugmentWide(_AutoAugmentBase):
478489
def __init__(
479490
self,
480491
num_magnitude_bins: int = 31,
481-
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
492+
interpolation: Union[str, InterpolationMode, int] = "nearest",
482493
fill: Union[_FillType, dict[Union[type, str], _FillType]] = None,
483494
):
484495
super().__init__(interpolation=interpolation, fill=fill)
@@ -521,9 +532,13 @@ class AugMix(_AutoAugmentBase):
521532
Default is ``-1``.
522533
alpha (float, optional): The hyperparameter for the probability distributions. Default is ``1.0``.
523534
all_ops (bool, optional): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
524-
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
525-
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
526-
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
535+
interpolation (str or InterpolationMode, optional): Desired interpolation enum defined by
536+
:class:`torchvision.transforms.v2.InterpolationMode`.
537+
Accepted string values are ``"nearest"``, ``"nearest-exact"``, ``"bilinear"``, ``"bicubic"``,
538+
``"box"``, ``"hamming"``, and ``"lanczos"``.
539+
``"box"``, ``"hamming"``, and ``"lanczos"`` are only supported for PIL images.
540+
The corresponding ``InterpolationMode`` enum values and Pillow integer
541+
constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well. Default is ``"bilinear"``.
527542
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
528543
image. If given a number, the value is used for all bands respectively.
529544
"""
@@ -559,7 +574,7 @@ def __init__(
559574
chain_depth: int = -1,
560575
alpha: float = 1.0,
561576
all_ops: bool = True,
562-
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
577+
interpolation: Union[str, InterpolationMode, int] = "bilinear",
563578
fill: Union[_FillType, dict[Union[type, str], _FillType]] = None,
564579
) -> None:
565580
super().__init__(interpolation=interpolation, fill=fill)

0 commit comments

Comments
 (0)