diff --git a/tensorflow_addons/activations/hardshrink.py b/tensorflow_addons/activations/hardshrink.py
index 5d66a2ef46..7817e60241 100644
--- a/tensorflow_addons/activations/hardshrink.py
+++ b/tensorflow_addons/activations/hardshrink.py
@@ -48,3 +48,20 @@ def _hardshrink_grad(op, grad):
     return _activation_so.ops.addons_hardshrink_grad(
         grad, op.inputs[0], op.get_attr("lower"), op.get_attr("upper")
     )
+
+
+def _hardshrink_py(
+    x: types.TensorLike, lower: Number = -0.5, upper: Number = 0.5
+) -> tf.Tensor:
+    if lower > upper:
+        raise ValueError(
+            "The value of lower is {} and should"
+            " not be greater than the value"
+            " of upper, which is {}.".format(lower, upper)
+        )
+    x = tf.convert_to_tensor(x)
+    mask_lower = x < lower
+    mask_upper = upper < x
+    mask = tf.logical_or(mask_lower, mask_upper)
+    mask = tf.cast(mask, x.dtype)
+    return x * mask
diff --git a/tensorflow_addons/activations/hardshrink_test.py b/tensorflow_addons/activations/hardshrink_test.py
index d5d17201f2..63eff43ec5 100644
--- a/tensorflow_addons/activations/hardshrink_test.py
+++ b/tensorflow_addons/activations/hardshrink_test.py
@@ -19,6 +19,7 @@ import tensorflow as tf
 
 from tensorflow_addons.activations import hardshrink
 from tensorflow_addons.utils import test_utils
+from tensorflow_addons.activations.hardshrink import _hardshrink_py
 
 
 @test_utils.run_all_in_graph_and_eager_modes
@@ -53,6 +54,30 @@ def test_theoretical_gradients(self, dtype):
         theoretical, numerical = tf.test.compute_gradient(hardshrink, [x])
         self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
 
+    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
+    def test_same_as_py_func(self, dtype):
+        np.random.seed(1234)
+        for _ in range(20):
+            self.verify_funcs_are_equivalent(dtype)
+
+    def verify_funcs_are_equivalent(self, dtype):
+        x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
+        x = tf.convert_to_tensor(x_np)
+        lower = np.random.uniform(-10, 10)
+        upper = lower + np.random.uniform(0, 10)
+
+        with tf.GradientTape(persistent=True) as t:
+            t.watch(x)
+            y_native = hardshrink(x, lower, upper)
+            y_py = _hardshrink_py(x, lower, upper)
+
+        self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)
+
+        grad_native = t.gradient(y_native, x)
+        grad_py = t.gradient(y_py, x)
+
+        self.assertAllCloseAccordingToType(grad_native, grad_py, atol=1e-4)
+
 
 if __name__ == "__main__":
     tf.test.main()
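
Note: the snippet below is not part of the patch; it is a minimal standalone sketch of the hardshrink behavior that _hardshrink_py implements, using illustrative input values. Elements inside [lower, upper] are zeroed, elements outside the band pass through unchanged.

import tensorflow as tf

x = tf.constant([-1.0, -0.3, 0.0, 0.3, 1.0])
lower, upper = -0.5, 0.5  # the defaults used by _hardshrink_py
# Keep values strictly below lower or strictly above upper; zero the rest.
mask = tf.cast(tf.logical_or(x < lower, x > upper), x.dtype)
print((x * mask).numpy())  # [-1.  0.  0.  0.  1.]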