Skip to content

Commit ffeaede

Browse files
authored
Do not cast to float32 when dtype is floating (#2400)
1 parent 95055be commit ffeaede

File tree

2 files changed

+81
-85
lines changed

2 files changed

+81
-85
lines changed

tensorflow_addons/image/distort_image_ops.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -118,20 +118,20 @@ def _adjust_hsv_in_yiq(
118118
# https://beesbuzz.biz/code/hsv_color_transforms.php
119119
yiq = tf.constant(
120120
[[0.299, 0.596, 0.211], [0.587, -0.274, -0.523], [0.114, -0.322, 0.312]],
121-
dtype=tf.float32,
121+
dtype=image.dtype,
122122
)
123123
yiq_inverse = tf.constant(
124124
[
125125
[1.0, 1.0, 1.0],
126126
[0.95617069, -0.2726886, -1.103744],
127127
[0.62143257, -0.64681324, 1.70062309],
128128
],
129-
dtype=tf.float32,
129+
dtype=image.dtype,
130130
)
131131
vsu = scale_value * scale_saturation * tf.math.cos(delta_hue)
132132
vsw = scale_value * scale_saturation * tf.math.sin(delta_hue)
133133
hsv_transform = tf.convert_to_tensor(
134-
[[scale_value, 0, 0], [0, vsu, vsw], [0, -vsw, vsu]], dtype=tf.float32
134+
[[scale_value, 0, 0], [0, vsu, vsw], [0, -vsw, vsu]], dtype=image.dtype
135135
)
136136
transform_matrix = yiq @ hsv_transform @ yiq_inverse
137137

@@ -172,15 +172,16 @@ def adjust_hsv_in_yiq(
172172
"""
173173
with tf.name_scope(name or "adjust_hsv_in_yiq"):
174174
image = tf.convert_to_tensor(image, name="image")
175-
delta_hue = tf.cast(delta_hue, dtype=tf.float32, name="delta_hue")
176-
scale_saturation = tf.cast(
177-
scale_saturation, dtype=tf.float32, name="scale_saturation"
178-
)
179-
scale_value = tf.cast(scale_value, dtype=tf.float32, name="scale_value")
180-
181175
# Remember original dtype to so we can convert back if needed
182176
orig_dtype = image.dtype
183-
image = tf.image.convert_image_dtype(image, tf.float32)
177+
if not image.dtype.is_floating:
178+
image = tf.image.convert_image_dtype(image, tf.float32)
179+
180+
delta_hue = tf.cast(delta_hue, dtype=image.dtype, name="delta_hue")
181+
scale_saturation = tf.cast(
182+
scale_saturation, dtype=image.dtype, name="scale_saturation"
183+
)
184+
scale_value = tf.cast(scale_value, dtype=image.dtype, name="scale_value")
184185

185186
if not options.TF_ADDONS_PY_OPS:
186187
warnings.warn(

tensorflow_addons/image/tests/distort_image_ops_test.py

+70-75
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import tensorflow as tf
2121
from tensorflow_addons.image import distort_image_ops
22+
from tensorflow_addons.utils import test_utils
2223

2324

2425
def _adjust_hue_in_yiq_np(x_np, delta_h):
@@ -36,20 +37,28 @@ def _adjust_hue_in_yiq_np(x_np, delta_h):
3637
"""
3738
assert x_np.shape[-1] == 3
3839
x_v = x_np.reshape([-1, 3])
39-
y_v = np.ndarray(x_v.shape, dtype=x_v.dtype)
4040
u = np.cos(delta_h)
4141
w = np.sin(delta_h)
4242
# Projection matrix from RGB to YIQ. Numbers from wikipedia
4343
# https://en.wikipedia.org/wiki/YIQ
4444
tyiq = np.array(
4545
[[0.299, 0.587, 0.114], [0.596, -0.274, -0.322], [0.211, -0.523, 0.312]]
46-
)
47-
y_v = np.dot(x_v, tyiq.T)
46+
).astype(x_v.dtype)
47+
inverse_tyiq = np.array(
48+
[
49+
[1.0, 0.95617069, 0.62143257],
50+
[1.0, -0.2726886, -0.64681324],
51+
[1.0, -1.103744, 1.70062309],
52+
]
53+
).astype(x_v.dtype)
54+
y_v = np.dot(x_v, tyiq.T).astype(x_v.dtype)
4855
# Hue rotation matrix in YIQ space.
49-
hue_rotation = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
56+
hue_rotation = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]]).astype(
57+
x_v.dtype
58+
)
5059
y_v = np.dot(y_v, hue_rotation.T)
5160
# Projecting back to RGB space.
52-
y_v = np.dot(y_v, np.linalg.inv(tyiq).T)
61+
y_v = np.dot(y_v, inverse_tyiq.T)
5362
return y_v.reshape(x_np.shape)
5463

5564

@@ -59,41 +68,34 @@ def _adjust_hue_in_yiq_tf(x_np, delta_h):
5968
return y
6069

6170

62-
def test_adjust_random_hue_in_yiq():
63-
x_shapes = [
64-
[2, 2, 3],
65-
[4, 2, 3],
66-
[2, 4, 3],
67-
[2, 5, 3],
68-
[1000, 1, 3],
69-
]
70-
test_styles = [
71-
"all_random",
72-
"rg_same",
73-
"rb_same",
74-
"gb_same",
75-
"rgb_same",
76-
]
77-
for x_shape in x_shapes:
78-
for test_style in test_styles:
79-
x_np = np.random.rand(*x_shape) * 255.0
80-
delta_h = (np.random.rand() * 2.0 - 1.0) * np.pi
81-
if test_style == "all_random":
82-
pass
83-
elif test_style == "rg_same":
84-
x_np[..., 1] = x_np[..., 0]
85-
elif test_style == "rb_same":
86-
x_np[..., 2] = x_np[..., 0]
87-
elif test_style == "gb_same":
88-
x_np[..., 2] = x_np[..., 1]
89-
elif test_style == "rgb_same":
90-
x_np[..., 1] = x_np[..., 0]
91-
x_np[..., 2] = x_np[..., 0]
92-
else:
93-
raise AssertionError("Invalid test style: %s" % (test_style))
94-
y_np = _adjust_hue_in_yiq_np(x_np, delta_h)
95-
y_tf = _adjust_hue_in_yiq_tf(x_np, delta_h)
96-
np.testing.assert_allclose(y_tf, y_np, rtol=2e-4, atol=1e-4)
71+
@pytest.mark.parametrize(
72+
"shape", ([2, 2, 3], [4, 2, 3], [2, 4, 3], [2, 5, 3], [1000, 1, 3])
73+
)
74+
@pytest.mark.parametrize(
75+
"style", ("all_random", "rg_same", "rb_same", "gb_same", "rgb_same")
76+
)
77+
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
78+
def test_adjust_random_hue_in_yiq(shape, style, dtype):
79+
x_np = (np.random.rand(*shape) * 255.0).astype(dtype)
80+
delta_h = (np.random.rand() * 2.0 - 1.0) * np.pi
81+
if style == "all_random":
82+
pass
83+
elif style == "rg_same":
84+
x_np[..., 1] = x_np[..., 0]
85+
elif style == "rb_same":
86+
x_np[..., 2] = x_np[..., 0]
87+
elif style == "gb_same":
88+
x_np[..., 2] = x_np[..., 1]
89+
elif style == "rgb_same":
90+
x_np[..., 1] = x_np[..., 0]
91+
x_np[..., 2] = x_np[..., 0]
92+
else:
93+
raise AssertionError("Invalid test style: %s" % (style))
94+
y_np = _adjust_hue_in_yiq_np(x_np, delta_h)
95+
y_tf = _adjust_hue_in_yiq_tf(x_np, delta_h)
96+
test_utils.assert_allclose_according_to_type(
97+
y_tf, y_np, atol=1e-4, rtol=2e-4, half_rtol=0.8
98+
)
9799

98100

99101
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@@ -229,41 +231,34 @@ def _adjust_saturation_in_yiq_np(x_np, scale):
229231
return y_v
230232

231233

232-
def test_adjust_random_saturation_in_yiq():
233-
x_shapes = [
234-
[2, 2, 3],
235-
[4, 2, 3],
236-
[2, 4, 3],
237-
[2, 5, 3],
238-
[1000, 1, 3],
239-
]
240-
test_styles = [
241-
"all_random",
242-
"rg_same",
243-
"rb_same",
244-
"gb_same",
245-
"rgb_same",
246-
]
247-
for x_shape in x_shapes:
248-
for test_style in test_styles:
249-
x_np = np.random.rand(*x_shape) * 255.0
250-
scale = np.random.rand() * 2.0 - 1.0
251-
if test_style == "all_random":
252-
pass
253-
elif test_style == "rg_same":
254-
x_np[..., 1] = x_np[..., 0]
255-
elif test_style == "rb_same":
256-
x_np[..., 2] = x_np[..., 0]
257-
elif test_style == "gb_same":
258-
x_np[..., 2] = x_np[..., 1]
259-
elif test_style == "rgb_same":
260-
x_np[..., 1] = x_np[..., 0]
261-
x_np[..., 2] = x_np[..., 0]
262-
else:
263-
raise AssertionError("Invalid test style: %s" % (test_style))
264-
y_baseline = _adjust_saturation_in_yiq_np(x_np, scale)
265-
y_tf = _adjust_saturation_in_yiq_tf(x_np, scale)
266-
np.testing.assert_allclose(y_tf, y_baseline, rtol=2e-4, atol=1e-4)
234+
@pytest.mark.parametrize(
235+
"shape", ([2, 2, 3], [4, 2, 3], [2, 4, 3], [2, 5, 3], [1000, 1, 3])
236+
)
237+
@pytest.mark.parametrize(
238+
"style", ("all_random", "rg_same", "rb_same", "gb_same", "rgb_same")
239+
)
240+
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
241+
def test_adjust_random_saturation_in_yiq(shape, style, dtype):
242+
x_np = (np.random.rand(*shape) * 255.0).astype(dtype)
243+
scale = np.random.rand() * 2.0 - 1.0
244+
if style == "all_random":
245+
pass
246+
elif style == "rg_same":
247+
x_np[..., 1] = x_np[..., 0]
248+
elif style == "rb_same":
249+
x_np[..., 2] = x_np[..., 0]
250+
elif style == "gb_same":
251+
x_np[..., 2] = x_np[..., 1]
252+
elif style == "rgb_same":
253+
x_np[..., 1] = x_np[..., 0]
254+
x_np[..., 2] = x_np[..., 0]
255+
else:
256+
raise AssertionError("Invalid test style: %s" % (style))
257+
y_baseline = _adjust_saturation_in_yiq_np(x_np, scale)
258+
y_tf = _adjust_saturation_in_yiq_tf(x_np, scale)
259+
test_utils.assert_allclose_according_to_type(
260+
y_tf, y_baseline, atol=1e-4, rtol=2e-4, half_rtol=0.8
261+
)
267262

268263

269264
def test_invalid_rank():

0 commit comments

Comments
 (0)