Skip to content

Commit 07febff

Browse files
authored
add tanhshrink_py (#1146)
* add tanhshrink_py
* format code
* remove useless test cases
1 parent eb416bf commit 07febff

File tree

2 files changed

+21
-16
lines changed

2 files changed

+21
-16
lines changed

tensorflow_addons/activations/tanhshrink.py

+4
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,7 @@ def tanhshrink(x: types.TensorLike) -> tf.Tensor:
3838
@tf.RegisterGradient("Addons>Tanhshrink")
def _tanhshrink_grad(op, grad):
    """Gradient for the custom "Addons>Tanhshrink" op.

    Delegates to the compiled tanhshrink gradient kernel, forwarding the
    upstream gradient together with the op's original input tensor.
    """
    # op.inputs[0] is the x that was fed to the forward tanhshrink op.
    original_input = op.inputs[0]
    return _activation_so.ops.addons_tanhshrink_grad(grad, original_input)
41+
42+
43+
def _tanhshrink_py(x):
    """Pure-TensorFlow reference implementation of tanhshrink: x - tanh(x).

    Used by the test suite to cross-check the compiled custom op.
    """
    shrunk = tf.math.tanh(x)
    return tf.math.subtract(x, shrunk)

tensorflow_addons/activations/tanhshrink_test.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import tensorflow as tf
2020
from tensorflow_addons.activations import tanhshrink
21+
from tensorflow_addons.activations.tanhshrink import _tanhshrink_py
2122
from tensorflow_addons.utils import test_utils
2223

2324

@@ -26,22 +27,22 @@ class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
2627
@parameterized.named_parameters(
2728
("float16", np.float16), ("float32", np.float32), ("float64", np.float64)
2829
)
29-
def test_tanhshrink(self, dtype):
30-
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
31-
expected_result = tf.constant(
32-
[-1.0359724, -0.23840582, 0.0, 0.23840582, 1.0359724], dtype=dtype
33-
)
34-
35-
self.assertAllCloseAccordingToType(tanhshrink(x), expected_result)
36-
37-
@parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
38-
def test_theoretical_gradients(self, dtype):
39-
# Only test theoretical gradients for float32 and float64
40-
# because of the instability of float16 while computing jacobian
41-
x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
42-
43-
theoretical, numerical = tf.test.compute_gradient(tanhshrink, [x])
44-
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
30+
def test_same_as_py_func(self, dtype):
31+
np.random.seed(1234)
32+
for _ in range(20):
33+
self.verify_funcs_are_equivalent(dtype)
34+
35+
def verify_funcs_are_equivalent(self, dtype):
36+
x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
37+
x = tf.convert_to_tensor(x_np)
38+
with tf.GradientTape(persistent=True) as t:
39+
t.watch(x)
40+
y_native = tanhshrink(x)
41+
y_py = _tanhshrink_py(x)
42+
self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)
43+
grad_native = t.gradient(y_native, x)
44+
grad_py = t.gradient(y_py, x)
45+
self.assertAllCloseAccordingToType(grad_native, grad_py, atol=1e-4)
4546

4647

4748
if __name__ == "__main__":

0 commit comments

Comments (0)