/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.activations;

import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TBool;
import org.tensorflow.types.family.TFloating;

/**
 * Exponential linear unit.
 *
 * <p>The exponential linear unit (ELU) with <code>alpha &gt; 0</code> is:
 *
 * <p><code>x</code> if <code>x &gt; 0</code> and <code>alpha * (exp(x) - 1)</code> if
 * <code>x &lt; 0</code>.
 *
 * <p>The ELU hyperparameter <code>alpha</code> controls the value to which an ELU saturates for
 * negative net inputs. ELUs diminish the vanishing gradient effect.
 *
 * <p>ELUs have negative values, which pushes the mean of the activations closer to zero. Mean
 * activations that are closer to zero enable faster learning, as they bring the gradient closer
 * to the natural gradient. ELUs saturate to a negative value as the argument gets smaller.
 * Saturation means a small derivative, which decreases the variation and the information
 * propagated to the next layer.
 *
 * <p>Example Usage:
 *
 * <pre>
 * Operand&lt;TFloat32&gt; input = ...;
 * ELU&lt;TFloat32&gt; elu = new ELU&lt;&gt;(tf, 2.0f);
 * Operand&lt;TFloat32&gt; result = elu.call(input);
 * </pre>
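 *
 * <p>With <code>alpha = 2.0f</code> as in the example, positive inputs pass through unchanged
 * and large negative inputs saturate toward <code>-2.0</code>.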
 *
 * @param <T> the data type of the activation
 * @see <a href="https://arxiv.org/abs/1511.07289">Clevert et al., 2016, Fast and Accurate Deep
 *     Network Learning by Exponential Linear Units (ELUs)</a>
 */
public class ELU<T extends TFloating> extends Activation<T> {

  private static final double ALPHA_DEFAULT = 1.0;

  /** A scalar, the slope of the negative section. */
  private final double alpha;

  /**
   * Creates a new ELU with alpha={@link #ALPHA_DEFAULT}.
   *
   * @param tf the TensorFlow Ops
   */
  public ELU(Ops tf) {
    this(tf, ALPHA_DEFAULT);
  }

  /**
   * Creates a new ELU.
   *
   * @param tf the TensorFlow Ops
   * @param alpha A scalar, the slope of the negative section. It controls the value to which an
   *     ELU saturates for negative net inputs.
   */
  public ELU(Ops tf, double alpha) {
    super(tf);
    this.alpha = alpha;
  }

  /**
   * Gets the calculation operation for the activation.
   *
   * @param input the input tensor
   * @return The operand for the activation
   */
  @Override
  public Operand<T> call(Operand<T> input) {
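    // tf.nn.elu natively computes ELU with alpha == 1: x for x > 0, and
    // exp(x) - 1 for x <= 0. For any other alpha, the saturating (negative)
    // part of that result is rescaled by alpha below, while positive values
    // are selected back unchanged; elu(x) > 0 exactly when x > 0, so testing
    // the result is equivalent to testing the input.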
    Operand<T> result = tf.nn.elu(input);
    if (alpha == 1.0) {
      return result;
    }
    DataType<T> dataType = input.asOutput().dataType();
    Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), dataType));
    Operand<TBool> cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), dataType));
    return tf.select(cond, result, y);
  }
}