Skip to content

Commit 95055be

Browse files
authored
Add python fallback for adjust_hsv_in_yiq (#2392)
1 parent 635c484 commit 95055be

File tree

2 files changed

+73
-14
lines changed

2 files changed

+73
-14
lines changed

tensorflow_addons/image/distort_image_ops.py

+65-7
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,15 @@
1414
# ==============================================================================
1515
"""Python layer for distort_image_ops."""
1616

17+
from typing import Optional
18+
import warnings
19+
1720
import tensorflow as tf
21+
22+
from tensorflow_addons import options
1823
from tensorflow_addons.utils.resource_loader import LazySO
1924
from tensorflow_addons.utils.types import Number, TensorLike
2025

21-
from typing import Optional
22-
2326
_distort_image_so = LazySO("custom_ops/image/_distort_image_ops.so")
2427

2528

@@ -99,6 +102,43 @@ def random_hsv_in_yiq(
99102
)
100103

101104

105+
def _adjust_hsv_in_yiq(
106+
image,
107+
delta_hue,
108+
scale_saturation,
109+
scale_value,
110+
):
111+
if image.shape.rank is not None and image.shape.rank < 3:
112+
raise ValueError("input must be at least 3-D.")
113+
if image.shape[-1] is not None and image.shape[-1] != 3:
114+
raise ValueError(
115+
"input must have 3 channels but instead has {}.".format(image.shape[-1])
116+
)
117+
# Construct hsv linear transformation matrix in YIQ space.
118+
# https://beesbuzz.biz/code/hsv_color_transforms.php
119+
yiq = tf.constant(
120+
[[0.299, 0.596, 0.211], [0.587, -0.274, -0.523], [0.114, -0.322, 0.312]],
121+
dtype=tf.float32,
122+
)
123+
yiq_inverse = tf.constant(
124+
[
125+
[1.0, 1.0, 1.0],
126+
[0.95617069, -0.2726886, -1.103744],
127+
[0.62143257, -0.64681324, 1.70062309],
128+
],
129+
dtype=tf.float32,
130+
)
131+
vsu = scale_value * scale_saturation * tf.math.cos(delta_hue)
132+
vsw = scale_value * scale_saturation * tf.math.sin(delta_hue)
133+
hsv_transform = tf.convert_to_tensor(
134+
[[scale_value, 0, 0], [0, vsu, vsw], [0, -vsw, vsu]], dtype=tf.float32
135+
)
136+
transform_matrix = yiq @ hsv_transform @ yiq_inverse
137+
138+
image = image @ transform_matrix
139+
return image
140+
141+
102142
def adjust_hsv_in_yiq(
103143
image: TensorLike,
104144
delta_hue: Number = 0,
@@ -132,13 +172,31 @@ def adjust_hsv_in_yiq(
132172
"""
133173
with tf.name_scope(name or "adjust_hsv_in_yiq"):
134174
image = tf.convert_to_tensor(image, name="image")
175+
delta_hue = tf.cast(delta_hue, dtype=tf.float32, name="delta_hue")
176+
scale_saturation = tf.cast(
177+
scale_saturation, dtype=tf.float32, name="scale_saturation"
178+
)
179+
scale_value = tf.cast(scale_value, dtype=tf.float32, name="scale_value")
135180

136181
# Remember original dtype to so we can convert back if needed
137182
orig_dtype = image.dtype
138-
flt_image = tf.image.convert_image_dtype(image, tf.dtypes.float32)
183+
image = tf.image.convert_image_dtype(image, tf.float32)
139184

140-
rgb_altered = _distort_image_so.ops.addons_adjust_hsv_in_yiq(
141-
flt_image, delta_hue, scale_saturation, scale_value
142-
)
185+
if not options.TF_ADDONS_PY_OPS:
186+
warnings.warn(
187+
"C++/CUDA kernel of `adjust_hsv_in_yiq` will be removed in Addons `0.13`.",
188+
DeprecationWarning,
189+
)
190+
try:
191+
image = _distort_image_so.ops.addons_adjust_hsv_in_yiq(
192+
image, delta_hue, scale_saturation, scale_value
193+
)
194+
except tf.errors.NotFoundError:
195+
options.warn_fallback("adjust_hsv_in_yiq")
196+
image = _adjust_hsv_in_yiq(
197+
image, delta_hue, scale_saturation, scale_value
198+
)
199+
else:
200+
image = _adjust_hsv_in_yiq(image, delta_hue, scale_saturation, scale_value)
143201

144-
return tf.image.convert_image_dtype(rgb_altered, orig_dtype)
202+
return tf.image.convert_image_dtype(image, orig_dtype)

tensorflow_addons/image/tests/distort_image_ops_test.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_invalid_rank_hsv():
101101
x_np = np.random.rand(2, 3) * 255.0
102102
delta_h = np.random.rand() * 2.0 - 1.0
103103
with pytest.raises(
104-
tf.errors.InvalidArgumentError, match="input must be at least 3-D"
104+
(tf.errors.InvalidArgumentError, ValueError), match="input must be at least 3-D"
105105
):
106106
_adjust_hue_in_yiq_tf(x_np, delta_h)
107107

@@ -111,7 +111,7 @@ def test_invalid_channels_hsv():
111111
x_np = np.random.rand(4, 2, 4) * 255.0
112112
delta_h = np.random.rand() * 2.0 - 1.0
113113
with pytest.raises(
114-
tf.errors.InvalidArgumentError,
114+
(tf.errors.InvalidArgumentError, ValueError),
115115
match="input must have 3 channels but instead has 4",
116116
):
117117
_adjust_hue_in_yiq_tf(x_np, delta_h)
@@ -190,7 +190,8 @@ def test_invalid_rank_value():
190190
scale = np.random.rand() * 2.0 - 1.0
191191
if tf.executing_eagerly():
192192
with pytest.raises(
193-
tf.errors.InvalidArgumentError, match="input must be at least 3-D"
193+
(tf.errors.InvalidArgumentError, ValueError),
194+
match="input must be at least 3-D",
194195
):
195196
_adjust_value_in_yiq_tf(x_np, scale)
196197
else:
@@ -205,7 +206,7 @@ def test_invalid_channels_value():
205206
scale = np.random.rand() * 2.0 - 1.0
206207
if tf.executing_eagerly():
207208
with pytest.raises(
208-
tf.errors.InvalidArgumentError,
209+
(tf.errors.InvalidArgumentError, ValueError),
209210
match="input must have 3 channels but instead has 4",
210211
):
211212
_adjust_value_in_yiq_tf(x_np, scale)
@@ -270,13 +271,13 @@ def test_invalid_rank():
270271
scale = np.random.rand() * 2.0 - 1.0
271272

272273
msg = "input must be at least 3-D"
273-
with pytest.raises(tf.errors.InvalidArgumentError, match=msg):
274+
with pytest.raises((tf.errors.InvalidArgumentError, ValueError), match=msg):
274275
_adjust_saturation_in_yiq_tf(x_np, scale).numpy()
275276

276277

277278
def test_invalid_channels():
278279
x_np = np.random.rand(4, 2, 4) * 255.0
279280
scale = np.random.rand() * 2.0 - 1.0
280-
msg = "input must have 3 channels but instead has 4 "
281-
with pytest.raises(tf.errors.InvalidArgumentError, match=msg):
281+
msg = "input must have 3 channels but instead has 4"
282+
with pytest.raises((tf.errors.InvalidArgumentError, ValueError), match=msg):
282283
_adjust_saturation_in_yiq_tf(x_np, scale).numpy()

0 commit comments

Comments
 (0)