19
19
20
20
import tensorflow as tf
21
21
from tensorflow_addons .image import distort_image_ops
22
+ from tensorflow_addons .utils import test_utils
22
23
23
24
24
25
def _adjust_hue_in_yiq_np (x_np , delta_h ):
@@ -36,20 +37,28 @@ def _adjust_hue_in_yiq_np(x_np, delta_h):
36
37
"""
37
38
assert x_np .shape [- 1 ] == 3
38
39
x_v = x_np .reshape ([- 1 , 3 ])
39
- y_v = np .ndarray (x_v .shape , dtype = x_v .dtype )
40
40
u = np .cos (delta_h )
41
41
w = np .sin (delta_h )
42
42
# Projection matrix from RGB to YIQ. Numbers from wikipedia
43
43
# https://en.wikipedia.org/wiki/YIQ
44
44
tyiq = np .array (
45
45
[[0.299 , 0.587 , 0.114 ], [0.596 , - 0.274 , - 0.322 ], [0.211 , - 0.523 , 0.312 ]]
46
- )
47
- y_v = np .dot (x_v , tyiq .T )
46
+ ).astype (x_v .dtype )
47
+ inverse_tyiq = np .array (
48
+ [
49
+ [1.0 , 0.95617069 , 0.62143257 ],
50
+ [1.0 , - 0.2726886 , - 0.64681324 ],
51
+ [1.0 , - 1.103744 , 1.70062309 ],
52
+ ]
53
+ ).astype (x_v .dtype )
54
+ y_v = np .dot (x_v , tyiq .T ).astype (x_v .dtype )
48
55
# Hue rotation matrix in YIQ space.
49
- hue_rotation = np .array ([[1.0 , 0.0 , 0.0 ], [0.0 , u , - w ], [0.0 , w , u ]])
56
+ hue_rotation = np .array ([[1.0 , 0.0 , 0.0 ], [0.0 , u , - w ], [0.0 , w , u ]]).astype (
57
+ x_v .dtype
58
+ )
50
59
y_v = np .dot (y_v , hue_rotation .T )
51
60
# Projecting back to RGB space.
52
- y_v = np .dot (y_v , np . linalg . inv ( tyiq ) .T )
61
+ y_v = np .dot (y_v , inverse_tyiq .T )
53
62
return y_v .reshape (x_np .shape )
54
63
55
64
@@ -59,41 +68,34 @@ def _adjust_hue_in_yiq_tf(x_np, delta_h):
59
68
return y
60
69
61
70
62
- def test_adjust_random_hue_in_yiq ():
63
- x_shapes = [
64
- [2 , 2 , 3 ],
65
- [4 , 2 , 3 ],
66
- [2 , 4 , 3 ],
67
- [2 , 5 , 3 ],
68
- [1000 , 1 , 3 ],
69
- ]
70
- test_styles = [
71
- "all_random" ,
72
- "rg_same" ,
73
- "rb_same" ,
74
- "gb_same" ,
75
- "rgb_same" ,
76
- ]
77
- for x_shape in x_shapes :
78
- for test_style in test_styles :
79
- x_np = np .random .rand (* x_shape ) * 255.0
80
- delta_h = (np .random .rand () * 2.0 - 1.0 ) * np .pi
81
- if test_style == "all_random" :
82
- pass
83
- elif test_style == "rg_same" :
84
- x_np [..., 1 ] = x_np [..., 0 ]
85
- elif test_style == "rb_same" :
86
- x_np [..., 2 ] = x_np [..., 0 ]
87
- elif test_style == "gb_same" :
88
- x_np [..., 2 ] = x_np [..., 1 ]
89
- elif test_style == "rgb_same" :
90
- x_np [..., 1 ] = x_np [..., 0 ]
91
- x_np [..., 2 ] = x_np [..., 0 ]
92
- else :
93
- raise AssertionError ("Invalid test style: %s" % (test_style ))
94
- y_np = _adjust_hue_in_yiq_np (x_np , delta_h )
95
- y_tf = _adjust_hue_in_yiq_tf (x_np , delta_h )
96
- np .testing .assert_allclose (y_tf , y_np , rtol = 2e-4 , atol = 1e-4 )
71
+ @pytest .mark .parametrize (
72
+ "shape" , ([2 , 2 , 3 ], [4 , 2 , 3 ], [2 , 4 , 3 ], [2 , 5 , 3 ], [1000 , 1 , 3 ])
73
+ )
74
+ @pytest .mark .parametrize (
75
+ "style" , ("all_random" , "rg_same" , "rb_same" , "gb_same" , "rgb_same" )
76
+ )
77
+ @pytest .mark .parametrize ("dtype" , [np .float16 , np .float32 , np .float64 ])
78
+ def test_adjust_random_hue_in_yiq (shape , style , dtype ):
79
+ x_np = (np .random .rand (* shape ) * 255.0 ).astype (dtype )
80
+ delta_h = (np .random .rand () * 2.0 - 1.0 ) * np .pi
81
+ if style == "all_random" :
82
+ pass
83
+ elif style == "rg_same" :
84
+ x_np [..., 1 ] = x_np [..., 0 ]
85
+ elif style == "rb_same" :
86
+ x_np [..., 2 ] = x_np [..., 0 ]
87
+ elif style == "gb_same" :
88
+ x_np [..., 2 ] = x_np [..., 1 ]
89
+ elif style == "rgb_same" :
90
+ x_np [..., 1 ] = x_np [..., 0 ]
91
+ x_np [..., 2 ] = x_np [..., 0 ]
92
+ else :
93
+ raise AssertionError ("Invalid test style: %s" % (style ))
94
+ y_np = _adjust_hue_in_yiq_np (x_np , delta_h )
95
+ y_tf = _adjust_hue_in_yiq_tf (x_np , delta_h )
96
+ test_utils .assert_allclose_according_to_type (
97
+ y_tf , y_np , atol = 1e-4 , rtol = 2e-4 , half_rtol = 0.8
98
+ )
97
99
98
100
99
101
@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
@@ -229,41 +231,34 @@ def _adjust_saturation_in_yiq_np(x_np, scale):
229
231
return y_v
230
232
231
233
232
- def test_adjust_random_saturation_in_yiq ():
233
- x_shapes = [
234
- [2 , 2 , 3 ],
235
- [4 , 2 , 3 ],
236
- [2 , 4 , 3 ],
237
- [2 , 5 , 3 ],
238
- [1000 , 1 , 3 ],
239
- ]
240
- test_styles = [
241
- "all_random" ,
242
- "rg_same" ,
243
- "rb_same" ,
244
- "gb_same" ,
245
- "rgb_same" ,
246
- ]
247
- for x_shape in x_shapes :
248
- for test_style in test_styles :
249
- x_np = np .random .rand (* x_shape ) * 255.0
250
- scale = np .random .rand () * 2.0 - 1.0
251
- if test_style == "all_random" :
252
- pass
253
- elif test_style == "rg_same" :
254
- x_np [..., 1 ] = x_np [..., 0 ]
255
- elif test_style == "rb_same" :
256
- x_np [..., 2 ] = x_np [..., 0 ]
257
- elif test_style == "gb_same" :
258
- x_np [..., 2 ] = x_np [..., 1 ]
259
- elif test_style == "rgb_same" :
260
- x_np [..., 1 ] = x_np [..., 0 ]
261
- x_np [..., 2 ] = x_np [..., 0 ]
262
- else :
263
- raise AssertionError ("Invalid test style: %s" % (test_style ))
264
- y_baseline = _adjust_saturation_in_yiq_np (x_np , scale )
265
- y_tf = _adjust_saturation_in_yiq_tf (x_np , scale )
266
- np .testing .assert_allclose (y_tf , y_baseline , rtol = 2e-4 , atol = 1e-4 )
234
+ @pytest .mark .parametrize (
235
+ "shape" , ([2 , 2 , 3 ], [4 , 2 , 3 ], [2 , 4 , 3 ], [2 , 5 , 3 ], [1000 , 1 , 3 ])
236
+ )
237
+ @pytest .mark .parametrize (
238
+ "style" , ("all_random" , "rg_same" , "rb_same" , "gb_same" , "rgb_same" )
239
+ )
240
+ @pytest .mark .parametrize ("dtype" , [np .float16 , np .float32 , np .float64 ])
241
+ def test_adjust_random_saturation_in_yiq (shape , style , dtype ):
242
+ x_np = (np .random .rand (* shape ) * 255.0 ).astype (dtype )
243
+ scale = np .random .rand () * 2.0 - 1.0
244
+ if style == "all_random" :
245
+ pass
246
+ elif style == "rg_same" :
247
+ x_np [..., 1 ] = x_np [..., 0 ]
248
+ elif style == "rb_same" :
249
+ x_np [..., 2 ] = x_np [..., 0 ]
250
+ elif style == "gb_same" :
251
+ x_np [..., 2 ] = x_np [..., 1 ]
252
+ elif style == "rgb_same" :
253
+ x_np [..., 1 ] = x_np [..., 0 ]
254
+ x_np [..., 2 ] = x_np [..., 0 ]
255
+ else :
256
+ raise AssertionError ("Invalid test style: %s" % (style ))
257
+ y_baseline = _adjust_saturation_in_yiq_np (x_np , scale )
258
+ y_tf = _adjust_saturation_in_yiq_tf (x_np , scale )
259
+ test_utils .assert_allclose_according_to_type (
260
+ y_tf , y_baseline , atol = 1e-4 , rtol = 2e-4 , half_rtol = 0.8
261
+ )
267
262
268
263
269
264
def test_invalid_rank ():
0 commit comments