Skip to content

Commit 03d02c1

Browse files
authored
Fix unwrap dtype (#2121)
1 parent 08741c9 commit 03d02c1

File tree

5 files changed

+48
-20
lines changed

5 files changed

+48
-20
lines changed

tensorflow_addons/image/tests/transform_ops_test.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
import numpy as np
1919
import tensorflow as tf
2020

21-
from tensorflow_addons.image import transform_ops
2221
from skimage import transform
2322

23+
from tensorflow_addons.image import transform_ops
24+
from tensorflow_addons.utils import test_utils
25+
2426
_DTYPES = {
2527
tf.dtypes.uint8,
2628
tf.dtypes.int32,
@@ -322,11 +324,13 @@ def test_unknown_shape():
322324
np.testing.assert_equal(image.numpy(), fn(image).numpy())
323325

324326

325-
# TODO: Parameterize on dtypes
326327
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
327-
def test_shear_x():
328-
image = np.random.randint(low=0, high=255, size=(4, 4, 3), dtype=np.uint8)
329-
color = tf.constant([255, 0, 255], tf.uint8)
328+
@pytest.mark.parametrize("dtype", _DTYPES - {tf.dtypes.float16})
329+
def test_shear_x(dtype):
330+
image = np.random.randint(low=0, high=255, size=(4, 4, 3)).astype(
331+
dtype.as_numpy_dtype
332+
)
333+
color = tf.constant([255, 0, 255], tf.int32)
330334
level = tf.random.uniform(shape=(), minval=0, maxval=1)
331335

332336
tf_image = tf.constant(image)
@@ -344,11 +348,13 @@ def test_shear_x():
344348
np.testing.assert_equal(sheared_img.numpy(), expected_img)
345349

346350

347-
# TODO: Parameterize on dtypes
348351
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
349-
def test_shear_y():
350-
image = np.random.randint(low=0, high=255, size=(4, 4, 3), dtype=np.uint8)
351-
color = tf.constant([255, 0, 255], tf.dtypes.uint8)
352+
@pytest.mark.parametrize("dtype", _DTYPES - {tf.dtypes.float16})
353+
def test_shear_y(dtype):
354+
image = np.random.randint(low=0, high=255, size=(4, 4, 3)).astype(
355+
dtype.as_numpy_dtype
356+
)
357+
color = tf.constant([255, 0, 255], tf.int32)
352358
level = tf.random.uniform(shape=(), minval=0, maxval=1)
353359

354360
tf_image = tf.constant(image)
@@ -363,4 +369,4 @@ def test_shear_y():
363369
mask = np.where(expected_img == -1)
364370
expected_img[mask[0], mask[1], :] = color
365371

366-
np.testing.assert_equal(sheared_img.numpy(), expected_img)
372+
test_utils.assert_allclose_according_to_type(sheared_img.numpy(), expected_img)

tensorflow_addons/image/tests/translate_ops_test.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414
# ==============================================================================
1515
"""Tests for translate ops."""
1616

17+
import numpy as np
1718
import pytest
19+
import scipy
1820
import tensorflow as tf
19-
import numpy as np
2021

21-
from tensorflow_addons.image import translate_ops
2222
from PIL import Image
2323

24+
from tensorflow_addons.image import translate_ops
25+
from tensorflow_addons.utils import test_utils
26+
2427
_DTYPES = {
2528
tf.dtypes.uint8,
2629
tf.dtypes.int32,
@@ -52,7 +55,6 @@ def test_translations_to_projective_transforms():
5255
np.testing.assert_equal(transform.numpy(), [[1, 0, 1, 0, 1, 1, 0, 0]])
5356

5457

55-
# TODO: Parameterize on dtypes
5658
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
5759
def test_translate_xy():
5860
image = np.random.randint(low=0, high=255, size=(4, 4, 3), dtype=np.uint8)
@@ -74,3 +76,23 @@ def test_translate_xy():
7476
)
7577

7678
np.testing.assert_equal(translated.numpy(), expected)
79+
80+
81+
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
82+
@pytest.mark.parametrize("dtype", _DTYPES - {tf.dtypes.float16})
83+
def test_translate_xy_scalar_replace(dtype):
84+
image = np.random.randint(low=0, high=128, size=(4, 4, 3)).astype(
85+
dtype.as_numpy_dtype
86+
)
87+
translate_to = np.random.randint(low=0, high=4, size=(2,))
88+
result = translate_ops.translate_xy(
89+
image=image, translate_to=translate_to, replace=1
90+
)
91+
expected = scipy.ndimage.shift(
92+
input=image,
93+
shift=(translate_to[1], translate_to[0], 0),
94+
order=0,
95+
mode="constant",
96+
cval=1,
97+
)
98+
test_utils.assert_allclose_according_to_type(result.numpy(), expected)

tensorflow_addons/image/transform_ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,11 @@ def rotate(
308308
return img_utils.from_4D_image(output, original_ndims)
309309

310310

311-
def shear_x(image: TensorLike, level: float, replace: int) -> TensorLike:
311+
def shear_x(image: TensorLike, level: float, replace: TensorLike) -> TensorLike:
312312
"""Perform shear operation on an image (x-axis).
313313
314314
Args:
315-
image: A 3D image Tensor.
315+
image: A 3D image `Tensor`.
316316
level: A float denoting shear element along y-axis
317317
replace: A one or three value 1D tensor to fill empty pixels.
318318
Returns:
@@ -327,7 +327,7 @@ def shear_x(image: TensorLike, level: float, replace: int) -> TensorLike:
327327
return unwrap(image, replace)
328328

329329

330-
def shear_y(image: TensorLike, level: float, replace: int) -> TensorLike:
330+
def shear_y(image: TensorLike, level: float, replace: TensorLike) -> TensorLike:
331331
"""Perform shear operation on an image (y-axis).
332332
333333
Args:

tensorflow_addons/image/translate_ops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,13 @@ def translate(
107107

108108

109109
def translate_xy(
110-
image: TensorLike, translate_to: TensorLike, replace: int
110+
image: TensorLike, translate_to: TensorLike, replace: TensorLike
111111
) -> TensorLike:
112112
"""Translates image in X or Y dimension.
113113
114114
Args:
115115
image: A 3D image `Tensor`.
116-
translate_to: A 1D `Tensor` to translate [x, y]
116+
translate_to: A 1D `Tensor` to translate `[x, y]`.
117117
replace: A one or three value 1D `Tensor` to fill empty pixels.
118118
Returns:
119119
Translated image along X or Y axis, with space outside image

tensorflow_addons/image/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@ def unwrap(image, replace):
134134
# Find all pixels where the last channel is zero.
135135
alpha_channel = flattened_image[:, 3]
136136

137-
replace = tf.constant(replace, tf.uint8)
137+
replace = tf.cast(replace, image.dtype)
138138
if tf.rank(replace) == 0:
139139
replace = tf.expand_dims(replace, 0)
140140
replace = tf.concat([replace, replace, replace], 0)
141-
replace = tf.concat([replace, tf.ones([1], dtype=image.dtype)], 0)
141+
replace = tf.concat([replace, tf.ones([1], dtype=replace.dtype)], 0)
142142

143143
# Where they are zero, fill them in with 'replace'.
144144
cond = tf.equal(alpha_channel, 1)

0 commit comments

Comments
 (0)