@@ -495,6 +495,7 @@ def adapt_fill(value, *, dtype):
495495 transforms .InterpolationMode .BICUBIC ,
496496 transforms .InterpolationMode .LANCZOS ,
497497]
498+ INTERPOLATION_MODES_STR = ["nearest" , "nearest-exact" , "bilinear" , "bicubic" , "lanczos" ]
498499
499500
500501def reference_affine_bounding_boxes_helper (bounding_boxes , * , affine_matrix , new_canvas_size = None , clamp = True ):
@@ -885,7 +886,7 @@ def _check_output_size(self, input, output, *, size, max_size):
885886 @pytest .mark .parametrize ("size" , OUTPUT_SIZES )
886887 # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
887888 # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
888- @pytest .mark .parametrize ("interpolation" , set (INTERPOLATION_MODES ) - {transforms . InterpolationMode . NEAREST })
889+ @pytest .mark .parametrize ("interpolation" , set (INTERPOLATION_MODES_STR ) - {"nearest" })
889890 @pytest .mark .parametrize ("use_max_size" , [True , False ])
890891 @pytest .mark .parametrize ("fn" , [F .resize , transform_cls_to_functional (transforms .Resize )])
891892 def test_image_correctness (self , size , interpolation , use_max_size , fn ):
@@ -898,7 +899,7 @@ def test_image_correctness(self, size, interpolation, use_max_size, fn):
898899 expected = F .to_image (F .resize (F .to_pil_image (image ), size = size , interpolation = interpolation , ** max_size_kwarg ))
899900
900901 self ._check_output_size (image , actual , size = size , ** max_size_kwarg )
901- atol = 2 if interpolation is transforms . InterpolationMode . LANCZOS else 1
902+ atol = 2 if interpolation == "lanczos" else 1
902903 torch .testing .assert_close (actual , expected , atol = atol , rtol = 0 )
903904
904905 def _reference_resize_bounding_boxes (self , bounding_boxes , format , * , size , max_size = None ):
@@ -1096,6 +1097,26 @@ def test_interpolation_int(self, interpolation, make_input):
10961097
10971098 assert_equal (actual , expected )
10981099
1100+ @pytest .mark .parametrize (
1101+ "interpolation_str, interpolation_enum" ,
1102+ [
1103+ ("nearest" , transforms .InterpolationMode .NEAREST ),
1104+ ("nearest-exact" , transforms .InterpolationMode .NEAREST_EXACT ),
1105+ ("bilinear" , transforms .InterpolationMode .BILINEAR ),
1106+ ("bicubic" , transforms .InterpolationMode .BICUBIC ),
1107+ ("lanczos" , transforms .InterpolationMode .LANCZOS ),
1108+ ],
1109+ )
1110+ @pytest .mark .parametrize ("fn" , [F .resize , transform_cls_to_functional (transforms .Resize )])
1111+ @pytest .mark .parametrize ("make_input" , [make_image_tensor , make_image , make_video ])
1112+ def test_interpolation_str (self , interpolation_str , interpolation_enum , fn , make_input ):
1113+ input = make_input (self .INPUT_SIZE )
1114+
1115+ expected = fn (input , size = self .OUTPUT_SIZES [0 ], interpolation = interpolation_enum , antialias = True )
1116+ actual = fn (input , size = self .OUTPUT_SIZES [0 ], interpolation = interpolation_str , antialias = True )
1117+
1118+ assert_equal (actual , expected )
1119+
10991120 def test_transform_unknown_size_error (self ):
11001121 with pytest .raises (ValueError , match = "size can be an integer, a sequence of one or two integers, or None" ):
11011122 transforms .Resize (size = object ())
@@ -1554,9 +1575,7 @@ def test_transform(self, make_input, device):
15541575 @pytest .mark .parametrize ("scale" , _CORRECTNESS_AFFINE_KWARGS ["scale" ])
15551576 @pytest .mark .parametrize ("shear" , _CORRECTNESS_AFFINE_KWARGS ["shear" ])
15561577 @pytest .mark .parametrize ("center" , _CORRECTNESS_AFFINE_KWARGS ["center" ])
1557- @pytest .mark .parametrize (
1558- "interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
1559- )
1578+ @pytest .mark .parametrize ("interpolation" , ["nearest" , "bilinear" ])
15601579 @pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
15611580 def test_functional_image_correctness (self , angle , translate , scale , shear , center , interpolation , fill ):
15621581 image = make_image (dtype = torch .uint8 , device = "cpu" )
@@ -1587,12 +1606,10 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent
15871606 )
15881607
15891608 mae = (actual .float () - expected .float ()).abs ().mean ()
1590- assert mae < 2 if interpolation is transforms . InterpolationMode . NEAREST else 8
1609+ assert mae < 2 if interpolation == "nearest" else 8
15911610
15921611 @pytest .mark .parametrize ("center" , _CORRECTNESS_AFFINE_KWARGS ["center" ])
1593- @pytest .mark .parametrize (
1594- "interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
1595- )
1612+ @pytest .mark .parametrize ("interpolation" , ["nearest" , "bilinear" ])
15961613 @pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
15971614 @pytest .mark .parametrize ("seed" , list (range (5 )))
15981615 def test_transform_image_correctness (self , center , interpolation , fill , seed ):
@@ -1611,7 +1628,7 @@ def test_transform_image_correctness(self, center, interpolation, fill, seed):
16111628 expected = F .to_image (transform (F .to_pil_image (image )))
16121629
16131630 mae = (actual .float () - expected .float ()).abs ().mean ()
1614- assert mae < 2 if interpolation is transforms . InterpolationMode . NEAREST else 8
1631+ assert mae < 2 if interpolation == "nearest" else 8
16151632
16161633 def _compute_affine_matrix (self , * , angle , translate , scale , shear , center ):
16171634 rot = math .radians (angle )
@@ -2142,9 +2159,7 @@ def test_transform(self, make_input, device):
21422159
21432160 @pytest .mark .parametrize ("angle" , _CORRECTNESS_AFFINE_KWARGS ["angle" ])
21442161 @pytest .mark .parametrize ("center" , _CORRECTNESS_AFFINE_KWARGS ["center" ])
2145- @pytest .mark .parametrize (
2146- "interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
2147- )
2162+ @pytest .mark .parametrize ("interpolation" , ["nearest" , "bilinear" ])
21482163 @pytest .mark .parametrize ("expand" , [False , True ])
21492164 @pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
21502165 def test_functional_image_correctness (self , angle , center , interpolation , expand , fill ):
@@ -2160,12 +2175,10 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand
21602175 )
21612176
21622177 mae = (actual .float () - expected .float ()).abs ().mean ()
2163- assert mae < 1 if interpolation is transforms . InterpolationMode . NEAREST else 6
2178+ assert mae < 1 if interpolation == "nearest" else 6
21642179
21652180 @pytest .mark .parametrize ("center" , _CORRECTNESS_AFFINE_KWARGS ["center" ])
2166- @pytest .mark .parametrize (
2167- "interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
2168- )
2181+ @pytest .mark .parametrize ("interpolation" , ["nearest" , "bilinear" ])
21692182 @pytest .mark .parametrize ("expand" , [False , True ])
21702183 @pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
21712184 @pytest .mark .parametrize ("seed" , list (range (5 )))
@@ -2189,7 +2202,7 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill,
21892202 expected = F .to_image (transform (F .to_pil_image (image )))
21902203
21912204 mae = (actual .float () - expected .float ()).abs ().mean ()
2192- assert mae < 1 if interpolation is transforms . InterpolationMode . NEAREST else 6
2205+ assert mae < 1 if interpolation == "nearest" else 6
21932206
21942207 def _compute_output_canvas_size (self , * , expand , canvas_size , affine_matrix ):
21952208 if not expand :
@@ -4150,6 +4163,9 @@ class TestAutoAugmentTransforms:
41504163 # rotate, are tested in their respective classes. The rest of the tests here are mostly smoke tests.
41514164
41524165 def _reference_shear_translate (self , image , * , transform_id , magnitude , interpolation , fill ):
4166+ if isinstance (interpolation , str ):
4167+ interpolation = transforms .InterpolationMode (interpolation )
4168+
41534169 if isinstance (image , PIL .Image .Image ):
41544170 input = image
41554171 else :
@@ -4173,9 +4189,7 @@ def _reference_shear_translate(self, image, *, transform_id, magnitude, interpol
41734189
41744190 @pytest .mark .parametrize ("transform_id" , ["ShearX" , "ShearY" , "TranslateX" , "TranslateY" ])
41754191 @pytest .mark .parametrize ("magnitude" , [0.3 , - 0.2 , 0.0 ])
4176- @pytest .mark .parametrize (
4177- "interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
4178- )
4192+ @pytest .mark .parametrize ("interpolation" , ["nearest" , "bilinear" ])
41794193 @pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
41804194 @pytest .mark .parametrize ("input_type" , ["Tensor" , "PIL" ])
41814195 def test_correctness_shear_translate (self , transform_id , magnitude , interpolation , fill , input_type ):
@@ -4208,7 +4222,7 @@ def test_correctness_shear_translate(self, transform_id, magnitude, interpolatio
42084222
42094223 if "Shear" in transform_id and input_type == "Tensor" :
42104224 mae = (actual .float () - expected .float ()).abs ().mean ()
4211- assert mae < (12 if interpolation is transforms . InterpolationMode . NEAREST else 5 )
4225+ assert mae < (12 if interpolation == "nearest" else 5 )
42124226 else :
42134227 assert_close (actual , expected , rtol = 0 , atol = 1 )
42144228
@@ -4537,7 +4551,7 @@ def test_transform(self, param, value, make_input):
45374551
45384552 # `InterpolationMode.NEAREST` is modeled after the buggy `INTER_NEAREST` interpolation of CV2.
45394553 # The PIL equivalent of `InterpolationMode.NEAREST` is `InterpolationMode.NEAREST_EXACT`
4540- @pytest .mark .parametrize ("interpolation" , set (INTERPOLATION_MODES ) - {transforms . InterpolationMode . NEAREST })
4554+ @pytest .mark .parametrize ("interpolation" , set (INTERPOLATION_MODES_STR ) - {"nearest" })
45414555 def test_functional_image_correctness (self , interpolation ):
45424556 image = make_image (self .INPUT_SIZE , dtype = torch .uint8 )
45434557
@@ -4550,9 +4564,7 @@ def test_functional_image_correctness(self, interpolation):
45504564 )
45514565 )
45524566
4553- torch .testing .assert_close (
4554- actual , expected , atol = 2 if interpolation is transforms .InterpolationMode .LANCZOS else 1 , rtol = 0
4555- )
4567+ torch .testing .assert_close (actual , expected , atol = 2 if interpolation == "lanczos" else 1 , rtol = 0 )
45564568
45574569 def _reference_resized_crop_bounding_boxes (self , bounding_boxes , * , top , left , height , width , size ):
45584570 new_height , new_width = size
@@ -5257,9 +5269,7 @@ def test_transform_error(self, distortion_scale):
52575269 transforms .RandomPerspective (distortion_scale = distortion_scale )
52585270
52595271 @pytest .mark .parametrize ("coefficients" , COEFFICIENTS )
5260- @pytest .mark .parametrize (
5261- "interpolation" , [transforms .InterpolationMode .NEAREST , transforms .InterpolationMode .BILINEAR ]
5262- )
5272+ @pytest .mark .parametrize ("interpolation" , ["nearest" , "bilinear" ])
52635273 @pytest .mark .parametrize ("fill" , CORRECTNESS_FILLS )
52645274 def test_image_functional_correctness (self , coefficients , interpolation , fill ):
52655275 image = make_image (dtype = torch .uint8 , device = "cpu" )
@@ -5278,7 +5288,7 @@ def test_image_functional_correctness(self, coefficients, interpolation, fill):
52785288 )
52795289 )
52805290
5281- if interpolation is transforms . InterpolationMode . BILINEAR :
5291+ if interpolation == "bilinear" :
52825292 abs_diff = (actual .float () - expected .float ()).abs ()
52835293 assert (abs_diff > 1 ).float ().mean () < 7e-2
52845294 mae = abs_diff .mean ()
0 commit comments