18
18
import numpy as np
19
19
import tensorflow as tf
20
20
21
- from tensorflow_addons .image import transform_ops
22
21
from skimage import transform
23
22
23
+ from tensorflow_addons .image import transform_ops
24
+ from tensorflow_addons .utils import test_utils
25
+
24
26
_DTYPES = {
25
27
tf .dtypes .uint8 ,
26
28
tf .dtypes .int32 ,
@@ -322,11 +324,13 @@ def test_unknown_shape():
322
324
np .testing .assert_equal (image .numpy (), fn (image ).numpy ())
323
325
324
326
325
- # TODO: Parameterize on dtypes
326
327
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
327
- def test_shear_x ():
328
- image = np .random .randint (low = 0 , high = 255 , size = (4 , 4 , 3 ), dtype = np .uint8 )
329
- color = tf .constant ([255 , 0 , 255 ], tf .uint8 )
328
+ @pytest .mark .parametrize ("dtype" , _DTYPES - {tf .dtypes .float16 })
329
+ def test_shear_x (dtype ):
330
+ image = np .random .randint (low = 0 , high = 255 , size = (4 , 4 , 3 )).astype (
331
+ dtype .as_numpy_dtype
332
+ )
333
+ color = tf .constant ([255 , 0 , 255 ], tf .int32 )
330
334
level = tf .random .uniform (shape = (), minval = 0 , maxval = 1 )
331
335
332
336
tf_image = tf .constant (image )
@@ -344,11 +348,13 @@ def test_shear_x():
344
348
np .testing .assert_equal (sheared_img .numpy (), expected_img )
345
349
346
350
347
- # TODO: Parameterize on dtypes
348
351
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
349
- def test_shear_y ():
350
- image = np .random .randint (low = 0 , high = 255 , size = (4 , 4 , 3 ), dtype = np .uint8 )
351
- color = tf .constant ([255 , 0 , 255 ], tf .dtypes .uint8 )
352
+ @pytest .mark .parametrize ("dtype" , _DTYPES - {tf .dtypes .float16 })
353
+ def test_shear_y (dtype ):
354
+ image = np .random .randint (low = 0 , high = 255 , size = (4 , 4 , 3 )).astype (
355
+ dtype .as_numpy_dtype
356
+ )
357
+ color = tf .constant ([255 , 0 , 255 ], tf .int32 )
352
358
level = tf .random .uniform (shape = (), minval = 0 , maxval = 1 )
353
359
354
360
tf_image = tf .constant (image )
@@ -363,4 +369,4 @@ def test_shear_y():
363
369
mask = np .where (expected_img == - 1 )
364
370
expected_img [mask [0 ], mask [1 ], :] = color
365
371
366
- np . testing . assert_equal (sheared_img .numpy (), expected_img )
372
+ test_utils . assert_allclose_according_to_type (sheared_img .numpy (), expected_img )
0 commit comments