|
14 | 14 | # ==============================================================================
|
15 | 15 | """Python layer for distort_image_ops."""
|
16 | 16 |
|
| 17 | +from typing import Optional |
| 18 | +import warnings |
| 19 | + |
17 | 20 | import tensorflow as tf
|
| 21 | + |
| 22 | +from tensorflow_addons import options |
18 | 23 | from tensorflow_addons.utils.resource_loader import LazySO
|
19 | 24 | from tensorflow_addons.utils.types import Number, TensorLike
|
20 | 25 |
|
21 |
| -from typing import Optional |
22 |
| - |
23 | 26 | _distort_image_so = LazySO("custom_ops/image/_distort_image_ops.so")
|
24 | 27 |
|
25 | 28 |
|
@@ -99,6 +102,43 @@ def random_hsv_in_yiq(
|
99 | 102 | )
|
100 | 103 |
|
101 | 104 |
|
| 105 | +def _adjust_hsv_in_yiq( |
| 106 | + image, |
| 107 | + delta_hue, |
| 108 | + scale_saturation, |
| 109 | + scale_value, |
| 110 | +): |
| 111 | + if image.shape.rank is not None and image.shape.rank < 3: |
| 112 | + raise ValueError("input must be at least 3-D.") |
| 113 | + if image.shape[-1] is not None and image.shape[-1] != 3: |
| 114 | + raise ValueError( |
| 115 | + "input must have 3 channels but instead has {}.".format(image.shape[-1]) |
| 116 | + ) |
| 117 | + # Construct hsv linear transformation matrix in YIQ space. |
| 118 | + # https://beesbuzz.biz/code/hsv_color_transforms.php |
| 119 | + yiq = tf.constant( |
| 120 | + [[0.299, 0.596, 0.211], [0.587, -0.274, -0.523], [0.114, -0.322, 0.312]], |
| 121 | + dtype=tf.float32, |
| 122 | + ) |
| 123 | + yiq_inverse = tf.constant( |
| 124 | + [ |
| 125 | + [1.0, 1.0, 1.0], |
| 126 | + [0.95617069, -0.2726886, -1.103744], |
| 127 | + [0.62143257, -0.64681324, 1.70062309], |
| 128 | + ], |
| 129 | + dtype=tf.float32, |
| 130 | + ) |
| 131 | + vsu = scale_value * scale_saturation * tf.math.cos(delta_hue) |
| 132 | + vsw = scale_value * scale_saturation * tf.math.sin(delta_hue) |
| 133 | + hsv_transform = tf.convert_to_tensor( |
| 134 | + [[scale_value, 0, 0], [0, vsu, vsw], [0, -vsw, vsu]], dtype=tf.float32 |
| 135 | + ) |
| 136 | + transform_matrix = yiq @ hsv_transform @ yiq_inverse |
| 137 | + |
| 138 | + image = image @ transform_matrix |
| 139 | + return image |
| 140 | + |
| 141 | + |
102 | 142 | def adjust_hsv_in_yiq(
|
103 | 143 | image: TensorLike,
|
104 | 144 | delta_hue: Number = 0,
|
@@ -132,13 +172,31 @@ def adjust_hsv_in_yiq(
|
132 | 172 | """
|
133 | 173 | with tf.name_scope(name or "adjust_hsv_in_yiq"):
|
134 | 174 | image = tf.convert_to_tensor(image, name="image")
|
| 175 | + delta_hue = tf.cast(delta_hue, dtype=tf.float32, name="delta_hue") |
| 176 | + scale_saturation = tf.cast( |
| 177 | + scale_saturation, dtype=tf.float32, name="scale_saturation" |
| 178 | + ) |
| 179 | + scale_value = tf.cast(scale_value, dtype=tf.float32, name="scale_value") |
135 | 180 |
|
136 | 181 | # Remember original dtype to so we can convert back if needed
|
137 | 182 | orig_dtype = image.dtype
|
138 |
| - flt_image = tf.image.convert_image_dtype(image, tf.dtypes.float32) |
| 183 | + image = tf.image.convert_image_dtype(image, tf.float32) |
139 | 184 |
|
140 |
| - rgb_altered = _distort_image_so.ops.addons_adjust_hsv_in_yiq( |
141 |
| - flt_image, delta_hue, scale_saturation, scale_value |
142 |
| - ) |
| 185 | + if not options.TF_ADDONS_PY_OPS: |
| 186 | + warnings.warn( |
| 187 | + "C++/CUDA kernel of `adjust_hsv_in_yiq` will be removed in Addons `0.13`.", |
| 188 | + DeprecationWarning, |
| 189 | + ) |
| 190 | + try: |
| 191 | + image = _distort_image_so.ops.addons_adjust_hsv_in_yiq( |
| 192 | + image, delta_hue, scale_saturation, scale_value |
| 193 | + ) |
| 194 | + except tf.errors.NotFoundError: |
| 195 | + options.warn_fallback("adjust_hsv_in_yiq") |
| 196 | + image = _adjust_hsv_in_yiq( |
| 197 | + image, delta_hue, scale_saturation, scale_value |
| 198 | + ) |
| 199 | + else: |
| 200 | + image = _adjust_hsv_in_yiq(image, delta_hue, scale_saturation, scale_value) |
143 | 201 |
|
144 |
| - return tf.image.convert_image_dtype(rgb_altered, orig_dtype) |
| 202 | + return tf.image.convert_image_dtype(image, orig_dtype) |
0 commit comments