@@ -18,6 +18,7 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow_addons.activations import tanhshrink
+from tensorflow_addons.activations.tanhshrink import _tanhshrink_py
 from tensorflow_addons.utils import test_utils


@@ -26,22 +27,22 @@ class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(
         ("float16", np.float16), ("float32", np.float32), ("float64", np.float64)
     )
-    def test_tanhshrink(self, dtype):
-        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
-        expected_result = tf.constant(
-            [-1.0359724, -0.23840582, 0.0, 0.23840582, 1.0359724], dtype=dtype
-        )
-
-        self.assertAllCloseAccordingToType(tanhshrink(x), expected_result)
-
-    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
-    def test_theoretical_gradients(self, dtype):
-        # Only test theoretical gradients for float32 and float64
-        # because of the instability of float16 while computing jacobian
-        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
-
-        theoretical, numerical = tf.test.compute_gradient(tanhshrink, [x])
-        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
+    def test_same_as_py_func(self, dtype):
+        np.random.seed(1234)
+        for _ in range(20):
+            self.verify_funcs_are_equivalent(dtype)
+
+    def verify_funcs_are_equivalent(self, dtype):
+        x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
+        x = tf.convert_to_tensor(x_np)
+        with tf.GradientTape(persistent=True) as t:
+            t.watch(x)
+            y_native = tanhshrink(x)
+            y_py = _tanhshrink_py(x)
+        self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)
+        grad_native = t.gradient(y_native, x)
+        grad_py = t.gradient(y_py, x)
+        self.assertAllCloseAccordingToType(grad_native, grad_py, atol=1e-4)


 if __name__ == "__main__":
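
The new test drives both the native op and the pure-Python fallback through a persistent `tf.GradientTape` and asserts that forward outputs and gradients agree. The diff imports `_tanhshrink_py` without showing its body; since tanhshrink is defined as f(x) = x − tanh(x), a minimal sketch of such a reference implementation, assuming it is built from plain composable TF ops (the body below is an assumption, not taken from this diff), might look like:

```python
import tensorflow as tf


def _tanhshrink_py(x):
    # Hypothetical sketch of the Python fallback being tested above.
    # Tanhshrink: f(x) = x - tanh(x). Composing plain TF ops lets autodiff
    # derive the gradient (d/dx = 1 - (1 - tanh^2(x)) = tanh^2(x)), which
    # the test compares against the gradient registered by the native op.
    x = tf.convert_to_tensor(x)
    return x - tf.math.tanh(x)
```

Note the tape is created with `persistent=True` because `t.gradient` is called twice, once per implementation; a non-persistent tape releases its resources after the first call.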