Skip to content

Commit a4b2ae4

Browse files
Added pure python implementation of lish (#1138)
1 parent cc376d0 commit a4b2ae4

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

tensorflow_addons/activations/lisht.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,7 @@ def lisht(x: types.TensorLike) -> tf.Tensor:
4242
@tf.RegisterGradient("Addons>Lisht")
4343
def _lisht_grad(op, grad):
4444
return _activation_so.ops.addons_lisht_grad(grad, op.inputs[0])
45+
46+
47+
def _lisht_py(x):
48+
return x * tf.math.tanh(x)

tensorflow_addons/activations/lisht_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import tensorflow as tf
2020
from tensorflow_addons.activations import lisht
21+
from tensorflow_addons.activations.lisht import _lisht_py
2122
from tensorflow_addons.utils import test_utils
2223

2324

@@ -42,6 +43,30 @@ def test_theoretical_gradients(self, dtype):
4243
theoretical, numerical = tf.test.compute_gradient(lisht, [x])
4344
self.assertAllCloseAccordingToType(theoretical, numerical, rtol=5e-4, atol=5e-4)
4445

46+
@parameterized.named_parameters(
47+
("float16", np.float16), ("float32", np.float32), ("float64", np.float64)
48+
)
49+
def test_same_as_py_func(self, dtype):
50+
np.random.seed(1234)
51+
for _ in range(20):
52+
self.verify_funcs_are_equivalent(dtype)
53+
54+
def verify_funcs_are_equivalent(self, dtype):
55+
x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
56+
x = tf.convert_to_tensor(x_np)
57+
58+
with tf.GradientTape(persistent=True) as t:
59+
t.watch(x)
60+
y_native = lisht(x)
61+
y_py = _lisht_py(x)
62+
63+
self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)
64+
65+
grad_native = t.gradient(y_native, x)
66+
grad_py = t.gradient(y_py, x)
67+
68+
self.assertAllCloseAccordingToType(grad_native, grad_py, atol=1e-4)
69+
4570

4671
if __name__ == "__main__":
4772
tf.test.main()

0 commit comments

Comments
 (0)