Skip to content

Commit cc376d0

Browse files
Add python implementation of softshrink (#1140)
* Add softshrink python op * Added check.
1 parent 2c1ed4f commit cc376d0

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

tensorflow_addons/activations/softshrink.py

+16
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,19 @@ def _softshrink_grad(op, grad):
4848
return _activation_so.ops.addons_softshrink_grad(
4949
grad, op.inputs[0], op.get_attr("lower"), op.get_attr("upper")
5050
)
51+
52+
53+
def _softshrink_py(x, lower, upper):
    """Pure-TensorFlow reference implementation of softshrink.

    Returns ``x - lower`` where ``x < lower``, ``x - upper`` where
    ``x > upper``, and ``0`` for values inside ``[lower, upper]``.

    Args:
        x: input tensor.
        lower: lower threshold of the shrink band.
        upper: upper threshold of the shrink band.

    Raises:
        ValueError: if ``lower`` is greater than ``upper``.
    """
    if lower > upper:
        raise ValueError(
            "The value of lower is {} and should"
            " not be higher than the value "
            "variable upper, which is {} .".format(lower, upper)
        )
    # Boolean masks for the three regions of the input.
    below = x < lower
    above = upper < x
    inside = tf.logical_not(tf.logical_or(below, above))
    # Cast to the input dtype so the masks can be used arithmetically.
    below = tf.cast(below, x.dtype)
    above = tf.cast(above, x.dtype)
    inside = tf.cast(inside, x.dtype)
    # x*(1 - inside) keeps only out-of-band values; subtracting the
    # matching threshold shifts each region toward zero.
    return x * (1 - inside) - below * lower - above * upper

tensorflow_addons/activations/softshrink_test.py

+25
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
import tensorflow as tf
2020
from tensorflow_addons.activations import softshrink
21+
from tensorflow_addons.activations.softshrink import _softshrink_py
2122
from tensorflow_addons.utils import test_utils
2223

2324

@@ -53,6 +54,30 @@ def test_theoretical_gradients(self, dtype):
5354
theoretical, numerical = tf.test.compute_gradient(softshrink, [x])
5455
self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
5556

57+
@parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
def test_same_as_py_func(self, dtype):
    """The python implementation must agree with the native op."""
    np.random.seed(1234)  # fixed seed so the random trials are reproducible
    for _trial in range(20):
        self.verify_funcs_are_equivalent(dtype)
62+
63+
def verify_funcs_are_equivalent(self, dtype):
    """Compare outputs and gradients of the native op and the python port
    on one random input / threshold draw."""
    sample = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
    inputs = tf.convert_to_tensor(sample)
    # Draw thresholds so that lower <= upper always holds.
    lower = np.random.uniform(-10, 10)
    upper = lower + np.random.uniform(0, 10)

    # persistent=True: the tape is consumed twice, once per implementation.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(inputs)
        y_native = softshrink(inputs, lower, upper)
        y_py = _softshrink_py(inputs, lower, upper)

    self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)

    self.assertAllCloseAccordingToType(
        tape.gradient(y_native, inputs), tape.gradient(y_py, inputs), atol=1e-4
    )
80+
5681

5782
if __name__ == "__main__":
5883
tf.test.main()

0 commit comments

Comments
 (0)