18
18
import numpy as np
19
19
import tensorflow as tf
20
20
from tensorflow_addons .activations import lisht
21
+ from tensorflow_addons .activations .lisht import _lisht_py
21
22
from tensorflow_addons .utils import test_utils
22
23
23
24
@@ -42,6 +43,30 @@ def test_theoretical_gradients(self, dtype):
42
43
theoretical , numerical = tf .test .compute_gradient (lisht , [x ])
43
44
self .assertAllCloseAccordingToType (theoretical , numerical , rtol = 5e-4 , atol = 5e-4 )
44
45
46
+ @parameterized .named_parameters (
47
+ ("float16" , np .float16 ), ("float32" , np .float32 ), ("float64" , np .float64 )
48
+ )
49
+ def test_same_as_py_func (self , dtype ):
50
+ np .random .seed (1234 )
51
+ for _ in range (20 ):
52
+ self .verify_funcs_are_equivalent (dtype )
53
+
54
+ def verify_funcs_are_equivalent (self , dtype ):
55
+ x_np = np .random .uniform (- 10 , 10 , size = (4 , 4 )).astype (dtype )
56
+ x = tf .convert_to_tensor (x_np )
57
+
58
+ with tf .GradientTape (persistent = True ) as t :
59
+ t .watch (x )
60
+ y_native = lisht (x )
61
+ y_py = _lisht_py (x )
62
+
63
+ self .assertAllCloseAccordingToType (y_native , y_py , atol = 1e-4 )
64
+
65
+ grad_native = t .gradient (y_native , x )
66
+ grad_py = t .gradient (y_py , x )
67
+
68
+ self .assertAllCloseAccordingToType (grad_native , grad_py , atol = 1e-4 )
69
+
45
70
46
71
if __name__ == "__main__" :
47
72
tf .test .main ()
0 commit comments