From 014ff44faf4588ac41746f8381323cfe4777db19 Mon Sep 17 00:00:00 2001
From: Jim Clarke
Date: Sun, 30 Aug 2020 18:42:48 -0400
Subject: [PATCH 01/14] Add ability to change learning rate between steps by adding a Placeholder into each Optimizer.

Also added to each Optimizer a corresponding Tensor that holds the value of the learning rate, and a feed dictionary that maps the placeholder to the Tensor, so that it can be fed into the runner when running or evaluating. When setLearningRate() is called, the learning rate tensor and the feed dictionary are updated.
---
 .../tensorflow/keras/optimizers/AdaDelta.java | 25 +-
 .../tensorflow/keras/optimizers/AdaGrad.java | 25 +-
 .../keras/optimizers/AdaGradDA.java | 38 +-
 .../org/tensorflow/keras/optimizers/Adam.java | 21 +-
 .../tensorflow/keras/optimizers/Adamax.java | 81 +-
 .../org/tensorflow/keras/optimizers/Ftrl.java | 56 +-
 .../tensorflow/keras/optimizers/Nadam.java | 60 +-
 .../keras/optimizers/OptimizerInterface.java | 22 +-
 .../keras/optimizers/Optimizers.java | 6 +-
 .../tensorflow/keras/optimizers/RMSProp.java | 19 +-
 .../org/tensorflow/keras/optimizers/SGD.java | 22 +-
 .../keras/optimizers/AdaDeltaTest.java | 345 +++---
 .../keras/optimizers/AdaGradDATest.java | 26 +-
 .../keras/optimizers/AdaGradTest.java | 28 +-
 .../tensorflow/keras/optimizers/AdamTest.java | 35 +-
 .../keras/optimizers/AdamaxTest.java | 33 +-
 .../tensorflow/keras/optimizers/FtrlTest.java | 106 +-
 .../keras/optimizers/NadamTest.java | 34 +-
 .../keras/optimizers/RMSPropTest.java | 535 ++++-----
 .../tensorflow/keras/optimizers/SGDTest.java | 34 +-
 .../keras/utils/EagerTestSession.java | 92 +-
 .../keras/utils/GraphTestSession.java | 275 +++--
 .../tensorflow/keras/utils/TestSession.java | 1023 ++++++++++++-----
 23 files changed, 1765 insertions(+), 1176 deletions(-)

diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaDelta.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaDelta.java
index b0a9dcf7d68..119a2311f3a 100644
--- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaDelta.java
+++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaDelta.java
@@ -31,7 +31,6 @@
 *
 *

Two accumulation steps are required: 1) the accumulation of gradients squared, 2) the * accumulation of updates squared. - * */ public class AdaDelta extends org.tensorflow.framework.optimizers.AdaDelta implements OptimizerInterface { @@ -45,11 +44,9 @@ public class AdaDelta extends org.tensorflow.framework.optimizers.AdaDelta public static final float EPSILON_DEFAULT = 1e-7F; private Map config = new HashMap<>(); - private float learningRate; private List initializers = new ArrayList<>(); - /** * Create an Adadelta optimizer with default name="Adadelta", learning_rate=0.001F, rho=0.95F, and * epsilon=1e-7F @@ -127,7 +124,7 @@ protected Optional prepare(String name) { case 1: return Optional.of(initializers.get(0)); default: - return Optional.of( tf.withSubScope(name).withControlDependencies(initializers).noOp()); + return Optional.of(tf.withSubScope(name).withControlDependencies(initializers).noOp()); } } @@ -146,9 +143,8 @@ public static AdaDelta fromConfig(Ops tf, Map config) { * Create an Adadelta optimizer * * @param tf the tensorflow Ops - * @param config a config object to initialize, the config - * object has keys for "name", "learning_rate", "rho" and "epsilon". If a key is missing the - * default value is used. + * @param config a config object to initialize, the config object has keys for "name", + * "learning_rate", "rho" and "epsilon". If a key is missing the default value is used. */ public static AdaDelta create(Ops tf, Map config) { String name = (String) config.get(NAME_KEY); @@ -171,7 +167,6 @@ public static AdaDelta create(Ops tf, Map config) { * @param epsilon A constant epsilon used to better conditioning the grad update. */ private void initConfig(float learningRate, float rho, float epsilon) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(RHO_RATE_KEY, rho); @@ -183,18 +178,4 @@ private void initConfig(float learningRate, float rho, float epsilon) { public Map getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } - - } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGrad.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGrad.java index 039cf4a0d82..98476fbdce5 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGrad.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGrad.java @@ -17,7 +17,13 @@ import java.util.HashMap; import java.util.Map; import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; + +import org.tensorflow.Operand; +import org.tensorflow.Tensor; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; /** * AdaGrad Optimizer that implements the AdaGrad algorithm. 
Adagrad is an optimizer with @@ -34,7 +40,6 @@ public class AdaGrad extends org.tensorflow.framework.optimizers.AdaGrad public static final float INITIAL_ACCUM__DEFAULT = 0.1f; private Map config = new HashMap<>(); - private float learningRate; /** * Create an AdaGrad Optimizer with name="Adagrad", learningRate=0.001F, and initial @@ -99,8 +104,9 @@ public AdaGrad(Ops tf, float learningRate, float initialAccumulatorValue) { */ public AdaGrad(Ops tf, String name, float learningRate, float initialAccumulatorValue) { super(assertGraph(tf), name, learningRate, initialAccumulatorValue); - if(initialAccumulatorValue < 0.0F) - throw new IllegalArgumentException( "initial_accumulator_value must be non-negative: " + initialAccumulatorValue); + if (initialAccumulatorValue < 0.0F) + throw new IllegalArgumentException( + "initial_accumulator_value must be non-negative: " + initialAccumulatorValue); initConfig(learningRate, initialAccumulatorValue); } @@ -141,7 +147,6 @@ public static AdaGrad create(Ops tf, Map config) { * @param initialAccumulatorValue the initial Accumulator value */ private void initConfig(float learningRate, float initialAccumulatorValue) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(INITIAL_ACCUM_KEY, initialAccumulatorValue); @@ -152,16 +157,4 @@ private void initConfig(float learningRate, float initialAccumulatorValue) { public Map getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGradDA.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGradDA.java index 2f15024bf56..f7d11697623 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGradDA.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/AdaGradDA.java @@ -14,10 +14,12 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; +import org.tensorflow.op.Ops; + import java.util.HashMap; import java.util.Map; + import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; -import org.tensorflow.op.Ops; /** Optimizer that implements the Adagrad Dual-Averaging algorithm. */ public class AdaGradDA extends org.tensorflow.framework.optimizers.AdaGradDA @@ -33,8 +35,7 @@ public class AdaGradDA extends org.tensorflow.framework.optimizers.AdaGradDA public static final float L1STRENGTH_DEFAULT = 0.0F; public static final float L2STRENGTH_DEFAULT = 0.0F; - private Map config = new HashMap<>(); - private float learningRate; + private final Map config = new HashMap<>(); /** * Create an AdagradDA Optimizer with default values name="adagrad-da". 
learning_rate=.001, @@ -85,11 +86,12 @@ public AdaGradDA( float l1Strength, float l2Strength) { super(assertGraph(tf), learningRate, initialAccumulatorValue, l1Strength, l2Strength); - if( initialAccumulatorValue < 0.0F) - throw new IllegalArgumentException("initial_accumulator_value must be non-negative: " + initialAccumulatorValue); - if(l1Strength < 0) + if (initialAccumulatorValue < 0.0F) + throw new IllegalArgumentException( + "initial_accumulator_value must be non-negative: " + initialAccumulatorValue); + if (l1Strength < 0) throw new IllegalArgumentException("l1Strength must be non-negative: " + l1Strength); - if(l2Strength < 0) + if (l2Strength < 0) throw new IllegalArgumentException("l2Strength must be non-negative: " + l2Strength); initConfig(learningRate, initialAccumulatorValue, l1Strength, l2Strength); } @@ -112,11 +114,12 @@ public AdaGradDA( float l1Strength, float l2Strength) { super(assertGraph(tf), name, learningRate, initialAccumulatorValue, l1Strength, l2Strength); - if( initialAccumulatorValue < 0.0F) - throw new IllegalArgumentException("initial_accumulator_value must be non-negative: " + initialAccumulatorValue); - if(l1Strength < 0) + if (initialAccumulatorValue < 0.0F) + throw new IllegalArgumentException( + "initial_accumulator_value must be non-negative: " + initialAccumulatorValue); + if (l1Strength < 0) throw new IllegalArgumentException("l1Strength must be non-negative: " + l1Strength); - if(l2Strength < 0) + if (l2Strength < 0) throw new IllegalArgumentException("l2Strength must be non-negative: " + l2Strength); initConfig(learningRate, initialAccumulatorValue, l1Strength, l2Strength); initConfig(learningRate, initialAccumulatorValue, l1Strength, l2Strength); @@ -168,7 +171,6 @@ public static AdaGradDA create(Ops tf, Map config) { */ private void initConfig( float learningRate, float initialAccumulatorValue, float l1Strength, float l2Strength) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(INITIAL_ACCUM_KEY, initialAccumulatorValue); @@ -181,16 +183,4 @@ private void initConfig( public Map getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adam.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adam.java index 5d74c7e27f4..593ddcd88f3 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adam.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adam.java @@ -14,11 +14,12 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; +import org.tensorflow.op.Ops; + import java.util.HashMap; import java.util.Map; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; -import org.tensorflow.op.Ops; /** Adam Optimizer that implements the Adam algorithm. 
*/ public class Adam extends org.tensorflow.framework.optimizers.Adam implements OptimizerInterface { @@ -33,8 +34,7 @@ public class Adam extends org.tensorflow.framework.optimizers.Adam implements Op public static final float BETA_ONE_DEFAULT = 0.9F; public static final float BETA_TWO_DEFAULT = 0.999F; - private float learningRate; - private Map config = new HashMap<>(); + private final Map config = new HashMap<>(); /** * Create an Adam Optimizer @@ -154,7 +154,6 @@ public static Adam create(Ops tf, Map config) { * 1 of the paper. Defaults to 1e-7. */ protected void initConfig(float learningRate, float betaOne, float betaTwo, float epsilon) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(EPSILON_KEY, epsilon); @@ -167,16 +166,4 @@ protected void initConfig(float learningRate, float betaOne, float betaTwo, floa public Map getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adamax.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adamax.java index a976a6e51dd..4158043e95b 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adamax.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Adamax.java @@ -14,29 +14,28 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.Scope; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyAdaMax; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.*; + +import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; + /** Adamax Optimizer that implements the Adamax algorithm. 
*/ public class Adamax extends org.tensorflow.framework.optimizers.Optimizer - implements OptimizerInterface { + implements OptimizerInterface, AutoCloseable { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -51,15 +50,17 @@ public class Adamax extends org.tensorflow.framework.optimizers.Optimizer public static final float BETA_ONE_DEFAULT = 0.9F; public static final float BETA_TWO_DEFAULT = 0.999F; - private Scope scope; - private Map config = new HashMap<>(); + private final Map config = new HashMap<>(); private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; + private final float betaOne; private final float betaTwo; private final float epsilon; - private Constant learningRateConst; private Constant epsilonConst; private Constant betaOneConst; private Constant betaTwoConst; @@ -117,10 +118,14 @@ public Adamax(Ops tf, String name, float learningRate) { public Adamax(Ops tf, float learningRate, float betaOne, float betaTwo, float epsilon) { super(assertGraph(tf)); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; - this.scope = tf.scope(); initConfig(learningRate, betaOne, betaTwo, epsilon); } @@ -138,10 +143,14 @@ public Adamax( Ops tf, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { super(assertGraph(tf), name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; - this.scope = tf.scope(); initConfig(learningRate, betaOne, betaTwo, epsilon); } @@ -191,8 +200,31 @@ public float getLearningRate() { /** {@inheritDoc} */ @Override - public void setLearningRate(float learningRate) { + public final void setLearningRate(float learningRate) { this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** + * Get the Feed Dictionary for the run methods to set the Placeholder values(s) + * + * @return the current Feed Dictionary for the run methods + */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } } /** {@inheritDoc} */ @@ -200,7 +232,6 @@ public void setLearningRate(float learningRate) { protected Optional prepare(String scopeName) { betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); - learningRateConst = tf.constant(learningRate); epsilonConst = tf.constant(epsilon); return Optional.empty(); @@ -238,16 +269,16 @@ protected Op applyDense(Output gradient, Output variable 
Variable firstMomentSlot = getSlot(variable, FIRST_MOMENT).get(); Variable secondMomentSlot = getSlot(variable, SECOND_MOMENT).get(); return ApplyAdaMax.create( - scope, - (Operand) variable, - (Operand) firstMomentSlot, - (Operand) secondMomentSlot, - (Operand) tf.dtypes.cast(betaOnePower, gradient.dataType()), - (Operand) tf.dtypes.cast(learningRateConst, gradient.dataType()), - (Operand) tf.dtypes.cast(betaOneConst, gradient.dataType()), - (Operand) tf.dtypes.cast(betaTwoConst, gradient.dataType()), - (Operand) tf.dtypes.cast(epsilonConst, gradient.dataType()), - (Operand) gradient); + tf.scope(), + (Operand) variable, + (Operand) firstMomentSlot, + (Operand) secondMomentSlot, + (Operand) tf.dtypes.cast(betaOnePower, gradient.dataType()), + (Operand) tf.dtypes.cast(this.learningRatePlaceholder, gradient.dataType()), + (Operand) tf.dtypes.cast(betaOneConst, gradient.dataType()), + (Operand) tf.dtypes.cast(betaTwoConst, gradient.dataType()), + (Operand) tf.dtypes.cast(epsilonConst, gradient.dataType()), + (Operand) gradient); } /** {@inheritDoc} */ diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Ftrl.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Ftrl.java index db73f60c77e..22dad158a4a 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Ftrl.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Ftrl.java @@ -14,26 +14,28 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.Session; -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyFtrl; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; + /** Ftrl (Follow the Regularized Leader) Optimizer that implements the FTRL algorithm. 
*/ public class Ftrl extends org.tensorflow.framework.optimizers.Optimizer - implements OptimizerInterface { + implements OptimizerInterface, AutoCloseable { public static final String LEARNING_RATE_KEY = "learning_rate"; public static final String LEARNING_RATE_POWER_KEY = "learning_rate_power"; @@ -55,15 +57,18 @@ public class Ftrl extends org.tensorflow.framework.optimizers.Optimizer private final String name; private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; private final float learningRatePower; private final float initialAccumulatorValue; private final float l1RegularizationStrength; private final float l2RegularizationStrength; private final float l2ShrinkageRegularizationStrength; - private Map config = new HashMap<>(); + private final Map config = new HashMap<>(); - private boolean useLocking = true; + private final boolean useLocking = true; /** * Create a Ftrl Optimizer @@ -161,6 +166,11 @@ public Ftrl( super(assertGraph(tf)); this.name = getOptimizerName(); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.learningRatePower = learningRatePower; this.initialAccumulatorValue = initialAccumulatorValue; this.l1RegularizationStrength = l1Strength; @@ -198,6 +208,11 @@ public Ftrl( super(assertGraph(tf), name); this.name = name; this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.learningRatePower = learningRatePower; this.initialAccumulatorValue = initialAccumulatorValue; this.l1RegularizationStrength = l1Strength; @@ -331,7 +346,7 @@ protected Op applyDense(Output gradient, Output variable accumSlot, // accum linearSlot, // linear gradient, // gradient - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), // lr + tf.dtypes.cast(this.learningRatePlaceholder, gradient.dataType()), // lr tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), // l1 tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), // l2 tf.dtypes.cast( @@ -360,7 +375,26 @@ public float getLearningRate() { /** {@inheritDoc} */ @Override - public void setLearningRate(float learningRate) { + public final void setLearningRate(float learningRate) { this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java index a2eba4ecb49..f9f796d7738 100644 
--- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java @@ -14,29 +14,25 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import org.tensorflow.DataType; -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; +import org.tensorflow.*; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; +import java.util.*; + +import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; + /** Nadam Optimizer that implements the NAdam algorithm. */ public class Nadam extends org.tensorflow.framework.optimizers.Optimizer - implements OptimizerInterface { + implements OptimizerInterface, AutoCloseable { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -55,6 +51,9 @@ public class Nadam extends org.tensorflow.framework.optimizers.Optimizer private final Map config = new HashMap<>(); private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; private final float betaOne; private final float betaTwo; private final float epsilon; @@ -63,7 +62,6 @@ public class Nadam extends org.tensorflow.framework.optimizers.Optimizer private long iterations = 0; - private Constant learningRateConst; private Constant betaOneConst; private Constant betaTwoConst; private Constant localStepConst; @@ -140,6 +138,11 @@ public Nadam(Ops tf, String name, float learningRate) { public Nadam(Ops tf, float learningRate, float betaOne, float betaTwo, float epsilon) { super(assertGraph(tf)); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -160,6 +163,11 @@ public Nadam( Ops tf, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { super(assertGraph(tf), name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -200,8 +208,31 @@ public float getLearningRate() { /** {@inheritDoc} */ @Override - public void setLearningRate(float learningRate) { + public final void setLearningRate(float learningRate) { this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = 
TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** + * Get the Feed Dictionary for the run methods to set the Placeholder values(s) + * + * @return the current Feed Dictionary for the run methods + */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } } /** {@inheritDoc} */ @@ -248,7 +279,6 @@ protected Optional prepare(String scopeName) { Constant one = tf.constant(1.0F); Constant point5 = tf.constant(0.5F); - learningRateConst = tf.constant(learningRate); betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); localStepConst = tf.constant(this.iterations + 1); @@ -350,7 +380,7 @@ protected Op applyDense(Output gradient, Output variable tf.math.sub( variable, tf.math.div( - tf.math.mul(tf.dtypes.cast(learningRateConst, dType), m_t_bar), + tf.math.mul(tf.dtypes.cast(this.learningRatePlaceholder, dType), m_t_bar), tf.math.add(tf.math.sqrt(v_t_prime), tf.dtypes.cast(epsilonConst, dType)))); // assign(var, var_t, use_locking=self._use_locking) return tf.assign(variable, var_t, Assign.useLocking(true)); diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java index 0074ecb0f0a..183c71dd976 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java @@ -14,10 +14,11 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.Map; import org.tensorflow.Graph; import org.tensorflow.op.Ops; +import java.util.Map; + /** The main Interface for Keras Optimizers */ public interface OptimizerInterface { @@ -32,8 +33,9 @@ public interface OptimizerInterface { * @throws java.lang.IllegalArgumentException if the TensorFlow Ops does not represent Graph mode */ static Graph assertGraph(Ops tf) { - if(!tf.scope().env().isGraph()) { - throw new IllegalArgumentException("Invalid environment, Optimizers can only be used in Graph Mode"); + if (!tf.scope().env().isGraph()) { + throw new IllegalArgumentException( + "Invalid environment, Optimizers can only be used in Graph Mode"); } return (Graph) tf.scope().env(); } @@ -44,18 +46,4 @@ static Graph assertGraph(Ops tf) { * @return the config object used to initialize the Optimizer */ Map getConfig(); - - /** - * Return the current learning rate - * - * @return the current learning rate - */ - float getLearningRate(); - - /** - * Set the learning rate - * - * @param learningRate the learning rate; - */ - void setLearningRate(float learningRate); } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java index 1facb307b38..aecd8dcf537 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java @@ -22,9 +22,9 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Function; +import java.util.function.Supplier; import java.util.logging.Level; import 
java.util.logging.Logger; -import java.util.function.Supplier; /** * Functions to get an Optimizer based on String name, an Optimizer class, or lambda function. @@ -79,8 +79,8 @@ public static Optimizer get(Ops tf, Function func) { * * @param optimizerFunction either a String that identifies the Optimizer, an Optimizer class, or * * an Optimizer object. - * @param custom_functions a map of Optimizer lambdas that will be queried if the Optimizer is - * not found in the standard keys + * @param custom_functions a map of Optimizer lambdas that will be queried if the Optimizer is not + * found in the standard keys * @return the Optimizer object */ public static Optimizer get( diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java index c66c6bdd388..03fc4c01f71 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java @@ -14,11 +14,12 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; +import org.tensorflow.op.Ops; + import java.util.HashMap; import java.util.Map; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; -import org.tensorflow.op.Ops; /** RMSProp Optimizer that implements the RMSProp algorithm. */ public class RMSProp extends org.tensorflow.framework.optimizers.RMSProp @@ -37,7 +38,6 @@ public class RMSProp extends org.tensorflow.framework.optimizers.RMSProp public static final boolean CENTERED_DEFAULT = false; private Map config = new HashMap<>(); - private float learningRate; /** * Create an RMSProp Optimizer with the following defaults, name="RMSProp", learning_rate=0.001, @@ -172,7 +172,6 @@ public static RMSProp create(Ops tf, Map config) { */ private void initConfig( float learningRate, float decay, float momentum, float epsilon, boolean centered) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(DECAY_KEY, decay); @@ -186,16 +185,4 @@ private void initConfig( public Map getConfig() { return config; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } } diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java index f89682f6820..5e7155c2ab5 100644 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java +++ b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java @@ -14,10 +14,12 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; +import org.tensorflow.op.Ops; + import java.util.HashMap; import java.util.Map; + import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; -import org.tensorflow.op.Ops; /** Stochastic Gradient Descent and momentum optimizer. 
*/ public class SGD extends org.tensorflow.framework.optimizers.Momentum @@ -32,7 +34,6 @@ public class SGD extends org.tensorflow.framework.optimizers.Momentum public static final boolean NESTEROV_DEFAULT = false; private Map config = new HashMap<>(); - private float learningRate; /** * Create a Stochastic Gradient Descent optimizer using defaults: name="SGD", learning_rate=0.01, @@ -102,7 +103,7 @@ public SGD(Ops tf, String name, float learningRate, float momentum) { */ public SGD(Ops tf, float learningRate, float momentum, boolean useNesterov) { super(assertGraph(tf), learningRate, momentum, useNesterov); - if(momentum < 0 || momentum > 1) + if (momentum < 0 || momentum > 1) throw new IllegalArgumentException("\"momentum\" must be between [0, 1]."); initConfig(learningRate, momentum, useNesterov); } @@ -119,7 +120,7 @@ public SGD(Ops tf, float learningRate, float momentum, boolean useNesterov) { */ public SGD(Ops tf, String name, float learningRate, float momentum, boolean useNesterov) { super(assertGraph(tf), name, learningRate, momentum, useNesterov); - if(momentum < 0 || momentum > 1) + if (momentum < 0 || momentum > 1) throw new IllegalArgumentException("\"momentum\" must be between [0, 1]."); initConfig(learningRate, momentum, useNesterov); } @@ -166,7 +167,6 @@ public static SGD create(Ops tf, Map config) { * @param useNesterov Whether to apply Nesterov momentum. Defaults to `false`. */ private void initConfig(float learningRate, float momentum, boolean useNesterov) { - this.learningRate = learningRate; config.put(NAME_KEY, this.getOptimizerName()); config.put(LEARNING_RATE_KEY, learningRate); config.put(MOMENTUM_KEY, momentum); @@ -179,18 +179,6 @@ public Map getConfig() { return config; } - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public void setLearningRate(float learningRate) { - this.learningRate = learningRate; - } - // overide the momentum name to return "SGD" /** {@inheritDoc} */ @Override diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java index 8a7c8af9fae..e8a3bc14d9b 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java @@ -14,26 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; -import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR; -import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR_UPDATE; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer.GradAndVar; -import static org.tensorflow.keras.optimizers.AdaDelta.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.AdaDelta.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.AdaDelta.LEARNING_RATE_DEFAULT; -import static org.tensorflow.keras.optimizers.AdaDelta.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.AdaDelta.RHO_DEFAULT; -import static 
org.tensorflow.keras.optimizers.AdaDelta.RHO_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -43,184 +25,181 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR; +import static org.tensorflow.framework.optimizers.AdaDelta.ACCUMULATOR_UPDATE; +import static org.tensorflow.keras.optimizers.AdaDelta.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for AdaDelta Optimizer */ public class AdaDeltaTest { - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; - private int index; + private int index; - public AdaDeltaTest() { - } + public AdaDeltaTest() {} - @BeforeAll - public static void setUpClass() { - } + @BeforeAll + public static void setUpClass() {} - @AfterAll - public static void tearDownClass() { - } + @AfterAll + public static void tearDownClass() {} - @BeforeEach - public void setUp() { - } + @BeforeEach + public void setUp() {} - @AfterEach - public void tearDown() { - } + @AfterEach + public void tearDown() {} - /** - * Test of create method, of class AdaDelta. - */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map config = new HashMap<>(); - config.put(NAME_KEY, "AdaDelta"); - config.put(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - config.put(RHO_RATE_KEY, RHO_DEFAULT); - config.put(EPSILON_KEY, EPSILON_DEFAULT); - AdaDelta expResult = new AdaDelta(tf); - AdaDelta result = AdaDelta.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } + /** Test of create method, of class AdaDelta. 
*/ + @Test + public void testCreate() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + Map config = new HashMap<>(); + config.put(NAME_KEY, "AdaDelta"); + config.put(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); + config.put(RHO_RATE_KEY, RHO_DEFAULT); + config.put(EPSILON_KEY, EPSILON_DEFAULT); + AdaDelta expResult = new AdaDelta(tf); + AdaDelta result = AdaDelta.create(tf, config); + assertEquals(expResult.getConfig(), result.getConfig()); } - - @Test - public void testConstructAdadeltaWithLR() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - AdaDelta opt = new AdaDelta(tf, 1.0F, 0.9F, 1.F); - AdaDelta opt2 = new AdaDelta(tf, 0.1F, 0.9F, 1.F); - AdaDelta opt3 = new AdaDelta(tf, 0.1F, 0.9F, 1e-8F); - String format = "AdaDelta{learningRate=%s, rho=%s, epsilon=%s}"; - String optExpected = String.format(format, 1.0F, 0.9F, 1.F); - String opt2Expected = String.format(format, 0.1F, 0.9F, 1.F); - String opt3Expected = String.format(format, 0.1F, 0.9F, 1e-8F); - - String optString = opt.toString(); - String opt2String = opt2.toString(); - String opt3String = opt3.toString(); - - assertEquals(optExpected, optString); - assertEquals(opt2Expected, opt2String); - assertEquals(opt3Expected, opt3String); - } - + } + + @Test + public void testConstructAdadeltaWithLR() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + AdaDelta opt = new AdaDelta(tf, 1.0F, 0.9F, 1.F); + AdaDelta opt1 = new AdaDelta(tf, "AdaDelta_1", 0.1F, 0.9F, 1.F); + AdaDelta opt2 = new AdaDelta(tf, "AdaDelta_2", 0.1F, 0.9F, 1e-8F); + String format = "AdaDelta{learningRate=%s, rho=%s, epsilon=%s}"; + String optExpected = String.format(format, 1.0F, 0.9F, 1.F); + String opt1Expected = String.format(format, 0.1F, 0.9F, 1.F); + String opt2Expected = String.format(format, 0.1F, 0.9F, 1e-8F); + + String optString = opt.toString(); + String opt1String = opt1.toString(); + String opt2String = opt2.toString(); + + assertEquals(optExpected, optString); + assertEquals(opt1Expected, opt1String); + assertEquals(opt2Expected, opt2String); } - - @Test - public void testConstructAdadeltaWithEpsilonValues() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - AdaDelta opt = new AdaDelta(tf); - Map config = opt.getConfig(); - assertEquals(EPSILON_DEFAULT, (float) config.get(EPSILON_KEY)); - - opt = new AdaDelta(tf, LEARNING_RATE_DEFAULT, RHO_DEFAULT, 1e-8F); - config = opt.getConfig(); - assertEquals(1e-8F, (float) config.get(EPSILON_KEY)); - } + } + + @Test + public void testConstructAdadeltaWithEpsilonValues() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + AdaDelta opt = new AdaDelta(tf); + Map config = opt.getConfig(); + assertEquals(EPSILON_DEFAULT, (float) config.get(EPSILON_KEY)); + + opt = new AdaDelta(tf, "AdaDelta_1", LEARNING_RATE_DEFAULT, RHO_DEFAULT, 1e-8F); + config = opt.getConfig(); + assertEquals(1e-8F, (float) config.get(EPSILON_KEY)); } - - @Test - public void testBasic() { - int num_updates = 4; // # number of ADADELTA steps to perform - float[] grads = {0.2F, 0.1F, 0.01F}; - float[] lrs = {1.0F, 0.5F, 0.1F}; - for (float grad : grads) { - for (float lr : lrs) { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] fgrads = {grad, grad}; - Shape 
shape = Shape.of(var0_init.length); - Variable var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); - - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); - - Constant cgrads = tf.constant(fgrads); - - float accum = 0.0F; - float accum_update = 0.0F; - float rho = 0.95F; - float epsilon = 1e-8F; - float epsilon1 = 1e-5F; - - /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); - gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); - gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); - - /* get the Optimizer */ - AdaDelta adaDelta = new AdaDelta(tf, lr, rho, epsilon); - - /** - * apply gradients - */ - Op adadelta_update = adaDelta.applyGradients(gradsAndVars, "AdaDeltaTest"); - - /* Create and validae the shapes of the slota */ - Variable[] slots = new Variable[2]; - Variable[] slotUpdates = new Variable[2]; - - slots[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR).get(); - assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); - - slotUpdates[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); - assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); - - slots[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR).get(); - assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); - - slotUpdates[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); - assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); - - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /** - * initialize the accumulators - */ - session.run(tf.init()); - - /** - * make sure the variables were initialized properly - */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); - - float[] updates = new float[num_updates]; - float tot_update = 0; - for (int step = 0; step < num_updates; step++) { - session.run(adadelta_update); - accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); - updates[step] = ((float) Math.sqrt(accum_update + epsilon) - * (float) (1 / Math.sqrt(accum + epsilon)) * grad); - accum_update = (accum_update * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); - tot_update += updates[step] * lr; - - for (int i = 0; i < 2; i++) { - session.evaluate(accum, slots[i]); - session.evaluate(accum_update, slotUpdates[i]); - } - - Float[] var0_initUpdate = {var0_init[0] - tot_update, var0_init[1] - tot_update}; - Float[] var1_initUpdate = {var1_init[0] - tot_update, var1_init[1] - tot_update}; - - session.evaluate(var0_initUpdate, var0); - session.evaluate(var1_initUpdate, var1); - - } - - } + } + + @Test + public void testBasic() { + int num_updates = 4; // # number of ADADELTA steps to perform + float[] grads = {0.2F, 0.1F, 0.01F}; + float[] lrs = {1.0F, 0.5F, 0.1F}; + for (float grad : grads) { + for (float lr : lrs) { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] fgrads = {grad, grad}; + Shape shape = Shape.of(var0_init.length); + Variable var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, 
tf.constant(var1_init)); + + Constant cgrads = tf.constant(fgrads); + + float accum = 0.0F; + float accum_update = 0.0F; + float rho = 0.95F; + float epsilon = 1e-8F; + float epsilon1 = 1e-5F; + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); + + /* get the Optimizer */ + AdaDelta adaDelta = new AdaDelta(tf, lr, rho, epsilon); + + /** apply gradients */ + Op adadelta_update = adaDelta.applyGradients(gradsAndVars, "AdaDeltaTest"); + + /* Create and validae the shapes of the slota */ + Variable[] slots = new Variable[2]; + Variable[] slotUpdates = new Variable[2]; + + slots[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + + slotUpdates[0] = adaDelta.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + + slots[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + + slotUpdates[1] = adaDelta.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + float[] updates = new float[num_updates]; + float tot_update = 0; + for (int step = 0; step < num_updates; step++) { + session.run(adadelta_update, adaDelta.getFeedDict()); + accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); + updates[step] = + ((float) Math.sqrt(accum_update + epsilon) + * (float) (1 / Math.sqrt(accum + epsilon)) + * grad); + accum_update = + (accum_update * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); + tot_update += updates[step] * lr; + + for (int i = 0; i < 2; i++) { + session.evaluate(accum, slots[i]); + session.evaluate(accum_update, slotUpdates[i]); } + + Float[] var0_initUpdate = {var0_init[0] - tot_update, var0_init[1] - tot_update}; + Float[] var1_initUpdate = {var1_init[0] - tot_update, var1_init[1] - tot_update}; + + session.evaluate(var0_initUpdate, var0); + session.evaluate(var1_initUpdate, var1); + } } + } } - + } } diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java index 3931db4da97..85f4220c4c7 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java @@ -14,20 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static 
org.tensorflow.keras.optimizers.AdaGradDA.INITIAL_ACCUM_KEY; -import static org.tensorflow.keras.optimizers.AdaGradDA.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -37,6 +25,16 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.keras.optimizers.AdaGradDA.INITIAL_ACCUM_KEY; +import static org.tensorflow.keras.optimizers.AdaGradDA.LEARNING_RATE_KEY; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for AdaGradDA Optimizer */ public class AdaGradDATest { @@ -116,7 +114,7 @@ public void testBasic() { session.evaluate(var0_init, var0); session.evaluate(var1_init, var1); - session.run(ada_update); + session.run(ada_update, instance.getFeedDict()); float[] expected0 = {-0.904534F, -1.603567F}; session.evaluate(expected0, var0); float[] expected1 = {-0.094821f, -0.189358f}; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java index 28c45c4c8c3..b6f1d7c88fc 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java @@ -14,21 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; -import static org.tensorflow.framework.optimizers.AdaGrad.ACCUMULATOR; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.AdaGrad.INITIAL_ACCUM_KEY; -import static org.tensorflow.keras.optimizers.AdaGrad.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -41,6 +28,17 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.AdaGrad.ACCUMULATOR; +import static org.tensorflow.keras.optimizers.AdaGrad.INITIAL_ACCUM_KEY; +import static org.tensorflow.keras.optimizers.AdaGrad.LEARNING_RATE_KEY; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for AdaGrad Optimizer */ public class AdaGradTest { private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; @@ -138,7 +136,7 @@ public void testBasic() { session.evaluate(var1_init, var1); for (int step = 0; step < numSteps; step++) { - session.run(ada_update); + session.run(ada_update, instance.getFeedDict()); accum0_np = caclulateAccum(accum0_np, grads0_np); var0_np = 
calculate(var0_np, accum0_np, grads0_np, learningRate); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java index 6f1d13d83d6..6a8f0f5078c 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java @@ -14,29 +14,9 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.Tensor; -import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; -import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.Adam.BETA_ONE_DEFAULT; -import static org.tensorflow.keras.optimizers.Adam.BETA_ONE_KEY; -import static org.tensorflow.keras.optimizers.Adam.BETA_TWO_DEFAULT; -import static org.tensorflow.keras.optimizers.Adam.BETA_TWO_KEY; -import static org.tensorflow.keras.optimizers.Adam.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.Adam.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.Adam.LEARNING_RATE_DEFAULT; -import static org.tensorflow.keras.optimizers.Adam.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -49,6 +29,17 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; +import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; +import static org.tensorflow.keras.optimizers.Adam.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for Adam Optimizer */ public class AdamTest { private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; @@ -203,7 +194,7 @@ public void testBasic() { assertEquals(powers[1], f.getFloat(), epsilon1); }); } - session.run(update); + session.run(update, instance.getFeedDict()); float lr_t = learningRate diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java index 1d3dc9e76bf..3f6b232c179 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java @@ -14,29 +14,9 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import 
org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.Tensor; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.Adamax.BETA_ONE_DEFAULT; -import static org.tensorflow.keras.optimizers.Adamax.BETA_ONE_KEY; -import static org.tensorflow.keras.optimizers.Adamax.BETA_TWO_DEFAULT; -import static org.tensorflow.keras.optimizers.Adamax.BETA_TWO_KEY; -import static org.tensorflow.keras.optimizers.Adamax.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.Adamax.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.Adamax.FIRST_MOMENT; -import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_DEFAULT; -import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.Adamax.SECOND_MOMENT; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -49,6 +29,15 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.keras.optimizers.Adamax.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for Adamax Optimizer */ public class AdamaxTest { private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; @@ -195,7 +184,7 @@ public void testBasic() { assertEquals(beta1_power, f.getFloat(), epsilon1); }); } - session.run(update); + session.run(update, instance.getFeedDict()); FloatNdArray[] resultNP = calculate(var0_np, grads0_np, step, m0, v0); var0_np = resultNP[VAR]; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/FtrlTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/FtrlTest.java index d61197348af..ba5d7ccb7a2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/FtrlTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/FtrlTest.java @@ -14,24 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.Ftrl.INITIAL_ACCUM_VALUE_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.L1STRENGTH_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.L2STRENGTH_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.L2_SHRINKAGE_REGULARIZATION_STRENGTH_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_POWER_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -41,8 +25,18 @@ 
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.types.TFloat32;
 
-/** Test cases for Ftrl Optimizer */
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.tensorflow.keras.optimizers.Ftrl.*;
+import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY;
+
+/** Test cases for the Ftrl Optimizer */
 public class FtrlTest {
+  private TestSession.Mode tf_mode = TestSession.Mode.GRAPH;
 
   int index;
@@ -147,7 +141,7 @@ public void testFtrlWithL1_L2_L2Shrinkage() {
       session.evaluate(var1_init, var1);
 
       for (int i = 0; i < numSteps; i++) {
-        session.run(ftrl_update);
+        session.run(ftrl_update, instance.getFeedDict());
       }
 
       float[] expectedVar0 = {-0.22578995F, -0.44345796F};
@@ -214,7 +208,7 @@ public void testFtrlWithL1() {
       session.evaluate(var1_init, var1);
 
       for (int i = 0; i < numSteps; i++) {
-        session.run(ftrl_update);
+        session.run(ftrl_update, instance.getFeedDict());
      }
 
      float[] expectedVar0 = {-7.66718769F, -10.91273689F};
@@ -282,7 +276,7 @@ public void testFtrlWithL1_L2() {
      session.evaluate(var1_init, var1);
 
      for (int i = 0; i < numSteps; i++) {
-        session.run(ftrl_update);
+        session.run(ftrl_update, instance.getFeedDict());
      }
 
      float[] expectedVar0 = {-0.24059935F, -0.46829352F};
@@ -293,6 +287,74 @@ public void testFtrlWithL1_L2() {
     }
   }
 
+  @Test
+  public void testChangingLearningRate() {
+    try (TestSession session = TestSession.createTestSession(tf_mode)) {
+      Ops tf = session.getTF();
+      int numSteps = 10;
+      float learningRate = 3.0F;
+      float[] var0_init = {1.0F, 2.0F};
+      float[] var1_init = {4.0F, 3.0F};
+      float[] grads0_init = {0.1F, 0.2F};
+      float[] grads1_init = {0.01F, 0.02F};
+      Shape shape0 = Shape.of(var0_init.length);
+      Shape shape1 = Shape.of(var1_init.length);
+      Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+      Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+      Assign var0Initializer = tf.assign(var0, tf.constant(var0_init));
+      Assign var1Initializer = tf.assign(var1, tf.constant(var1_init));
+
+      Constant grads0 = tf.constant(grads0_init);
+      Constant grads1 = tf.constant(grads1_init);
+
+      Ftrl instance =
+          new Ftrl(
+              tf,
+              learningRate,
+              Ftrl.LEARNING_RATE_POWER_DEFAULT, // learningRatePower
+              0.1F, // initial_accumulator_value
+              0.001F, // l1_regularization_strength
+              2.0F, // l2_regularization_strength
+              Ftrl
+                  .L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT // l2_shrinkage_regularization_strength
+              );
+
+      /* build the GradsAndVars */
+      List gradsAndVars = new ArrayList<>();
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+      gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+      Op ftrl_update = instance.applyGradients(gradsAndVars, "FtrlTest");
+
+      /* initialize the local variables */
+      session.run(var0Initializer);
+      session.run(var1Initializer);
+
+      /* initialize the accumulators */
+      session.run(tf.init());
+      float[][][] expected = {
+        {{-0.022833f, -0.038881f}, {-0.002141f, -0.004474f}},
+        {{-0.037825f, -0.067760f}, {-0.003717f, -0.007587f}},
+        {{-0.019528f, -0.034022f}, {-0.001979f, -0.004008f}},
+        {{-0.003895f, -0.007653f}, {-0.000355f, -0.000720f}},
+        {{-0.000596f, -0.001364f}, {-0.000046f, -0.000094f}},
+        {{-0.000084f, -0.000221f}, {-0.000006f, -0.000012f}},
+        {{-0.000011f, -0.000034f}, {-0.000001f, -0.000001f}},
+        {{-0.000002f, -0.000005f}, {-0.000000f, -0.000000f}},
+        {{-0.000000f, -0.000001f}, {-0.000000f, -0.000000f}},
{{-0.000000f, -0.000000f}, {-0.000000f, -0.000000f}} + }; + for (int i = 0; i < numSteps; i++) { + session.run(ftrl_update, instance.getFeedDict()); + session.evaluate(expected[i][0], var0); + session.evaluate(expected[i][1], var1); + learningRate *= 0.1f; + instance.setLearningRate(learningRate); + } + } + } + @Test public void doTestFtrlwithoutRegularization() { float[] var0_init = {0.0F, 0.0F}; @@ -339,7 +401,7 @@ public void doTestFtrlwithoutRegularization() { session.evaluate(var1_init, var1); for (int i = 0; i < numSteps; i++) { - session.run(ftrl_update); + session.run(ftrl_update, instance.getFeedDict()); } float[] expectedVar0 = {-2.60260963F, -4.29698515F}; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java index 2b8bce40471..6314b4b8b4c 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java @@ -14,29 +14,9 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.Tensor; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.Nadam.BETA_ONE_DEFAULT; -import static org.tensorflow.keras.optimizers.Nadam.BETA_ONE_KEY; -import static org.tensorflow.keras.optimizers.Nadam.BETA_TWO_DEFAULT; -import static org.tensorflow.keras.optimizers.Nadam.BETA_TWO_KEY; -import static org.tensorflow.keras.optimizers.Nadam.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.Nadam.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.Nadam.FIRST_MOMENT; -import static org.tensorflow.keras.optimizers.Nadam.LEARNING_RATE_DEFAULT; -import static org.tensorflow.keras.optimizers.Nadam.SECOND_MOMENT; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -49,6 +29,16 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_KEY; +import static org.tensorflow.keras.optimizers.Nadam.*; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; + /** Test cases for Nadam Optimizer */ public class NadamTest { private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; @@ -199,7 +189,7 @@ public void testBasic() { for (int step = 0; step < numSteps; step++) { - session.run(update); + session.run(update, instance.getFeedDict()); float mut = Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java 
b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java index b8fb4f40ee9..2a43bdb3df2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java @@ -14,30 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static org.tensorflow.framework.optimizers.RMSProp.MG; -import static org.tensorflow.framework.optimizers.RMSProp.MOMENTUM; -import static org.tensorflow.framework.optimizers.RMSProp.RMS; -import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.CENTERED_DEFAULT; -import static org.tensorflow.keras.optimizers.RMSProp.CENTERED_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.DECAY_DEFAULT; -import static org.tensorflow.keras.optimizers.RMSProp.DECAY_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.EPSILON_DEFAULT; -import static org.tensorflow.keras.optimizers.RMSProp.EPSILON_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.MOMENTUM_DEFAULT; -import static org.tensorflow.keras.optimizers.RMSProp.MOMENTUM_KEY; import org.tensorflow.keras.utils.ND; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -50,258 +28,281 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; -/** Test cases for RMSProp Optimizer */ -public class RMSPropTest { - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; - - final int VAR_T = 0; - final int MG_T = 1; - final int RMS_T = 2; - final int MOM_T = 3; - - int index; - - public RMSPropTest() { - } - - @BeforeAll - public static void setUpClass() { - } - - @AfterAll - public static void tearDownClass() { - } - - @BeforeEach - public void setUp() { - } +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; - @AfterEach - public void tearDown() { - } +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.RMSProp.*; +import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_KEY; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; +import static org.tensorflow.keras.optimizers.RMSProp.*; - /** - * Test of create method, of class RMSProp. 
- */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map config = new HashMap<>(); - config.put(NAME_KEY, "Ftrl"); - config.put(LEARNING_RATE_KEY, 2.0F); - config.put(DECAY_KEY, DECAY_DEFAULT); - config.put(MOMENTUM_KEY, MOMENTUM_DEFAULT); - config.put(EPSILON_KEY, EPSILON_DEFAULT); - config.put(CENTERED_KEY, CENTERED_DEFAULT); - Ftrl expResult = new Ftrl(tf, 2.0F); - Ftrl result = Ftrl.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } +/** Test cases for RMSProp Optimizer */ +public class RMSPropTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; + + final int VAR_T = 0; + final int MG_T = 1; + final int RMS_T = 2; + final int MOM_T = 3; + + int index; + + public RMSPropTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test of create method, of class RMSProp. */ + @Test + public void testCreate() { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + Map config = new HashMap<>(); + config.put(NAME_KEY, "Ftrl"); + config.put(LEARNING_RATE_KEY, 2.0F); + config.put(DECAY_KEY, DECAY_DEFAULT); + config.put(MOMENTUM_KEY, MOMENTUM_DEFAULT); + config.put(EPSILON_KEY, EPSILON_DEFAULT); + config.put(CENTERED_KEY, CENTERED_DEFAULT); + Ftrl expResult = new Ftrl(tf, 2.0F); + Ftrl result = Ftrl.create(tf, config); + assertEquals(expResult.getConfig(), result.getConfig()); } + } + + Object[][] _test_param_values = { + // learning_rate, rho (decay), momentum, epsilon, centered + {0.05F, 0.9F, 0.0F, 1e-3F, true}, + {0.05F, 0.9F, 0.0F, 1e-3F, false}, + {0.1F, 0.9F, 0.0F, 1e-3F, true}, + {0.01F, 0.9F, 0.0F, 1e-5F, true}, + {0.01F, 0.9F, 0.9F, 1e-5F, true} + }; + + @Test + public void testDense() { + + int numSteps = 3; + + for (int run = 0; run < _test_param_values.length; run++) { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + session.setEpsilon(1e-2f); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.2F}; + final float epsilon1 = 1e-2F; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); - Object[][] _test_param_values = { // learning_rate, rho (decay), momentum, epsilon, centered - {0.05F, 0.9F, 0.0F, 1e-3F, true}, - {0.05F, 0.9F, 0.0F, 1e-3F, false}, - {0.1F, 0.9F, 0.0F, 1e-3F, true}, - {0.01F, 0.9F, 0.0F, 1e-5F, true}, - {0.01F, 0.9F, 0.9F, 1e-5F, true} - }; - - @Test - public void testDense() { - - int numSteps = 3; - - for (int run = 0; run < _test_param_values.length; run++) { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - 
session.setEpsilon(1e-2f); - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.2F}; - float[] grads1_init = {0.01F, 0.2F}; - final float epsilon1 = 1e-2F; - - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); - - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); - - Constant grads0 = tf.constant(grads0_init); - Constant grads1 = tf.constant(grads1_init); - - // learning_rate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) _test_param_values[run][0]; - float decay = (float) _test_param_values[run][1]; - float momentum = (float) _test_param_values[run][2]; - float epsilon = (float) _test_param_values[run][3]; - boolean centered = (boolean) _test_param_values[run][4]; - - RMSProp instance = new RMSProp(tf, - learningRate, - decay, - momentum, - epsilon, - centered); - - /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - - Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); - - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /** - * initialize the accumulators - */ - session.run(tf.init()); - - /** - * make sure the variables were initialized properly - */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); - - Variable mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; - Variable mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; - Variable mom0 = momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; - Variable mom1 = momentum > 0.F ? 
instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; - Variable rms0 = instance.getSlot(var0.asOutput(), RMS).get(); - Variable rms1 = instance.getSlot(var1.asOutput(), RMS).get(); - - float[] zeros = {0.0F, 0.0F}; - float[] ones = {1.0F, 1.0F}; // temp to match RMSProp - FloatNdArray mg0_np = NdArrays.vectorOf(zeros); - FloatNdArray mg1_np = NdArrays.vectorOf(zeros); - FloatNdArray rms0_np = NdArrays.vectorOf(ones); - FloatNdArray rms1_np = NdArrays.vectorOf(ones); - FloatNdArray mom0_np = NdArrays.vectorOf(zeros); - FloatNdArray mom1_np = NdArrays.vectorOf(zeros); - - - - for (int i = 0; i < numSteps; i++) { - session.run(update); - FloatNdArray[] result0 = calc(var0_np, grads0_np, mg0_np, rms0_np, - mom0_np, learningRate, decay, momentum, epsilon, centered); - var0_np = result0[VAR_T]; - mg0_np = result0[MG_T]; - rms0_np = result0[RMS_T]; - mom0_np = result0[MOM_T]; - - FloatNdArray[] result1 = calc(var1_np, grads1_np, mg1_np, rms1_np, - mom1_np, learningRate, decay, momentum, epsilon, centered); - - var1_np = result1[VAR_T]; - mg1_np = result1[MG_T]; - rms1_np = result1[RMS_T]; - mom1_np = result1[MOM_T]; - - if (centered) { - session.evaluate(mg0_np, mg0); - session.evaluate(mg0_np, mg0); - } - if (momentum > 0.F) { - session.evaluate(mom0_np, mom0); - session.evaluate(mom1_np, mom1); - } - - /* TODO the values returned from rms slot, do not match what I see in the python test */ - session.evaluate(rms0_np, rms0); - session.evaluate(rms1_np, rms1); - - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); - } - } - } - } - - FloatNdArray[] calc(FloatNdArray var_np, FloatNdArray grad_np, FloatNdArray mg_np, - FloatNdArray rms_np, FloatNdArray mom, float lr, float decay, float momentum, - float epsilon, boolean centered) { - - FloatNdArray[] result = new FloatNdArray[4]; // var_t, mg_t, rms_t, mom_t - result[RMS_T] = calcRMS(rms_np, grad_np, decay); // RMS - - FloatNdArray denom_t; - if (centered) { - result[MG_T] = calcMG(mg_np, grad_np, decay); - //rms_t - mg_t * mg_t - denom_t = ND.sub(result[RMS_T], ND.square(result[MG_T])); - } else { - result[MG_T] = mg_np; - denom_t = rms_np; - } - if (momentum > 0.F) { - //momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) - result[MOM_T] = calcMom(momentum, mom, lr, grad_np, denom_t, epsilon); - //var_t = var - mom_t - result[VAR_T] = ND.sub(var_np, result[MOM_T]); - } else { - result[MOM_T] = mom; - result[VAR_T] = calcVar(var_np, grad_np, lr, denom_t, epsilon); + float learningRate = (float) (float) _test_param_values[run][0]; + float decay = (float) _test_param_values[run][1]; + float momentum = (float) _test_param_values[run][2]; + float epsilon = (float) _test_param_values[run][3]; + boolean centered = (boolean) _test_param_values[run][4]; + + RMSProp instance = new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + Variable mg0 = centered ? 
instance.getSlot(var0.asOutput(), MG).get() : null; + Variable mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; + Variable mom0 = + momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; + Variable mom1 = + momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; + Variable rms0 = instance.getSlot(var0.asOutput(), RMS).get(); + Variable rms1 = instance.getSlot(var1.asOutput(), RMS).get(); + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0_np = NdArrays.vectorOf(zeros); + FloatNdArray mg1_np = NdArrays.vectorOf(zeros); + FloatNdArray rms0_np = NdArrays.vectorOf(ones); + FloatNdArray rms1_np = NdArrays.vectorOf(ones); + FloatNdArray mom0_np = NdArrays.vectorOf(zeros); + FloatNdArray mom1_np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedDict()); + FloatNdArray[] result0 = + calc( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0_np = result0[VAR_T]; + mg0_np = result0[MG_T]; + rms0_np = result0[RMS_T]; + mom0_np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1_np = result1[VAR_T]; + mg1_np = result1[MG_T]; + rms1_np = result1[RMS_T]; + mom1_np = result1[MOM_T]; + + if (centered) { + session.evaluate(mg0_np, mg0); + session.evaluate(mg0_np, mg0); + } + if (momentum > 0.F) { + session.evaluate(mom0_np, mom0); + session.evaluate(mom1_np, mom1); + } + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + session.evaluate(rms0_np, rms0); + session.evaluate(rms1_np, rms1); + + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); } - - - return result; - + } } - - private FloatNdArray calcRMS(FloatNdArray rms_np, FloatNdArray grad_np, float decay) { - //rms * rho + (1 - rho) * g * g - FloatNdArray rms_rho = ND.mul(rms_np, decay); - FloatNdArray squareG = ND.square(grad_np); - float oneRHO = 1.0F - decay; - FloatNdArray decayG2 = ND.mul(oneRHO, squareG); - FloatNdArray result = ND.add(rms_rho, decayG2); - return result; - } - - private FloatNdArray calcMG(FloatNdArray mg_np, FloatNdArray grad_np, float decay) { - //mg_t = mg * rho + (1 - rho) * g - FloatNdArray mg_rho = ND.mul(mg_np, decay); - float oneRHO = 1.0F - decay; - FloatNdArray decayG = ND.mul(oneRHO, grad_np); - FloatNdArray result = ND.add(mg_rho, decayG); - return result; - + } + + FloatNdArray[] calc( + FloatNdArray var_np, + FloatNdArray grad_np, + FloatNdArray mg_np, + FloatNdArray rms_np, + FloatNdArray mom, + float lr, + float decay, + float momentum, + float epsilon, + boolean centered) { + + FloatNdArray[] result = new FloatNdArray[4]; // var_t, mg_t, rms_t, mom_t + result[RMS_T] = calcRMS(rms_np, grad_np, decay); // RMS + + FloatNdArray denom_t; + if (centered) { + result[MG_T] = calcMG(mg_np, grad_np, decay); + // rms_t - mg_t * mg_t + denom_t = ND.sub(result[RMS_T], ND.square(result[MG_T])); + } else { + result[MG_T] = mg_np; + denom_t = rms_np; } - - private FloatNdArray calcMom(float momentum, FloatNdArray mom, float lr, - FloatNdArray grad_np, FloatNdArray denom_t, float epsilon) { - // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) - FloatNdArray moMo = ND.mul(momentum, mom); - FloatNdArray dividend = ND.mul(lr, grad_np); - FloatNdArray divisor = ND.sqrt(ND.add(denom_t, epsilon)); - 
FloatNdArray quotient = ND.div(dividend, divisor); - FloatNdArray result = ND.add(moMo, quotient); - return result; - + if (momentum > 0.F) { + // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) + result[MOM_T] = calcMom(momentum, mom, lr, grad_np, denom_t, epsilon); + // var_t = var - mom_t + result[VAR_T] = ND.sub(var_np, result[MOM_T]); + } else { + result[MOM_T] = mom; + result[VAR_T] = calcVar(var_np, grad_np, lr, denom_t, epsilon); } - private FloatNdArray calcVar(FloatNdArray var_np, FloatNdArray grad_np, float lr, - FloatNdArray denom_t, float epsilon) { - // var - lr * g / (np.sqrt(denom_t) + epsilon) - FloatNdArray dividend = ND.mul(lr, grad_np); - FloatNdArray divisor = ND.add(ND.sqrt(denom_t), epsilon); - FloatNdArray quotient = ND.div(dividend, divisor); - FloatNdArray result = ND.sub(var_np, quotient); - return result; - - } + return result; + } + + private FloatNdArray calcRMS(FloatNdArray rms_np, FloatNdArray grad_np, float decay) { + // rms * rho + (1 - rho) * g * g + FloatNdArray rms_rho = ND.mul(rms_np, decay); + FloatNdArray squareG = ND.square(grad_np); + float oneRHO = 1.0F - decay; + FloatNdArray decayG2 = ND.mul(oneRHO, squareG); + FloatNdArray result = ND.add(rms_rho, decayG2); + return result; + } + + private FloatNdArray calcMG(FloatNdArray mg_np, FloatNdArray grad_np, float decay) { + // mg_t = mg * rho + (1 - rho) * g + FloatNdArray mg_rho = ND.mul(mg_np, decay); + float oneRHO = 1.0F - decay; + FloatNdArray decayG = ND.mul(oneRHO, grad_np); + FloatNdArray result = ND.add(mg_rho, decayG); + return result; + } + + private FloatNdArray calcMom( + float momentum, + FloatNdArray mom, + float lr, + FloatNdArray grad_np, + FloatNdArray denom_t, + float epsilon) { + // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) + FloatNdArray moMo = ND.mul(momentum, mom); + FloatNdArray dividend = ND.mul(lr, grad_np); + FloatNdArray divisor = ND.sqrt(ND.add(denom_t, epsilon)); + FloatNdArray quotient = ND.div(dividend, divisor); + FloatNdArray result = ND.add(moMo, quotient); + return result; + } + + private FloatNdArray calcVar( + FloatNdArray var_np, FloatNdArray grad_np, float lr, FloatNdArray denom_t, float epsilon) { + // var - lr * g / (np.sqrt(denom_t) + epsilon) + FloatNdArray dividend = ND.mul(lr, grad_np); + FloatNdArray divisor = ND.add(ND.sqrt(denom_t), epsilon); + FloatNdArray quotient = ND.div(dividend, divisor); + FloatNdArray result = ND.sub(var_np, quotient); + return result; + } } diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java index 7e12b957f84..3d24b85239a 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java @@ -14,24 +14,8 @@ =======================================================================*/ package org.tensorflow.keras.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.*; -import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; +import org.junit.jupiter.api.*; import org.tensorflow.framework.optimizers.Optimizer; -import static 
org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.SGD.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.SGD.MOMENTUM_DEFAULT; -import static org.tensorflow.keras.optimizers.SGD.MOMENTUM_KEY; -import static org.tensorflow.keras.optimizers.SGD.NESTEROV_DEFAULT; -import static org.tensorflow.keras.optimizers.SGD.NESTEROV_KEY; import org.tensorflow.keras.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -41,6 +25,16 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; +import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; +import static org.tensorflow.keras.optimizers.SGD.*; + /** Test cases for SGD Optimizer */ public class SGDTest { @@ -134,7 +128,7 @@ public void testBasic() { session.evaluate(var0_init, var0); session.evaluate(var1_init, var1); - session.run(update); // 1 step + session.run(update, instance.getFeedDict()); // 1 step float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; @@ -194,7 +188,7 @@ public void testMomentum() { session.evaluate(var0_init, var0); session.evaluate(var1_init, var1); - session.run(update); // 1 step + session.run(update, instance.getFeedDict()); // 1 step float[] expectedMomentum0 = {0.1F, 0.1F}; float[] expectedMomentum1 = {0.01F, 0.01F}; @@ -206,7 +200,7 @@ public void testMomentum() { session.evaluate(expectedVar0, var0); session.evaluate(expectedVar1, var1); - session.run(update); // step 2 + session.run(update, instance.getFeedDict()); // step 2 float[] expectedMomentum0_2 = {(0.9f * 0.1f + 0.1f), (0.9f * 0.1f + 0.1f)}; float[] expectedMomentum1_2 = {(0.9f * 0.01f + 0.01f), (0.9f * 0.01f + 0.01f)}; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java index 6b7ebf9e2f2..6d286c311ff 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java @@ -14,37 +14,30 @@ =======================================================================*/ package org.tensorflow.keras.utils; -import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; -import org.tensorflow.DataType; -import org.tensorflow.EagerSession; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.Session; +import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.TString; +import org.tensorflow.types.*; import org.tensorflow.types.family.TNumber; import 
org.tensorflow.types.family.TType;
 
-/** Eaager Mode Test Session */
+import java.io.PrintWriter;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Predicate;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/** Eager Mode Test Session */
 public class EagerTestSession extends TestSession {
 
   private final EagerSession session;
   private final Ops tf;
 
-  /** Create an Eager mode test session. */
+  /** Create an EagerTestSession */
   public EagerTestSession() {
     this.session = EagerSession.create();
     this.tf = Ops.create(session).withName("test");
   }
@@ -57,8 +50,9 @@ public Ops getTF() {
   }
 
   /**
-   * Get the TensorFlow EagerSession instance
-   * @return the TensorFlow EagerSession instance
+   * Returns the EagerSession for this test session
+   *
+   * @return the EagerSession for this test session
    */
   public EagerSession getSession() {
     return session;
   }
@@ -90,7 +84,22 @@ public EagerSession getEagerSession() {
 
   /** {@inheritDoc} */
   @Override
-  public void evaluate(double expected, Operand input) {
+  public void run(Op op) {
+    /* Empty */
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public void run(Op op, Map, Tensor> feedDict) {
+    /* Empty */
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public void evaluate(
+      double expected,
+      Operand input,
+      Map, Tensor> feedDict) {
     DataType dtype = input.asOutput().dataType();
     if (dtype == TFloat32.DTYPE) {
       Operand o = (Operand) input;
@@ -169,7 +178,10 @@ public void evaluate(double expected, Operand input) {
 
   /** {@inheritDoc} */
   @Override
-  public void evaluate(Number[] expected, Output input) {
+  public void evaluate(
+      Number[] expected,
+      Output input,
+      Map, Tensor> feedDict) {
     int size = input.shape().size() == 0 ?
1 : (int) input.shape().size();
     assertEquals(
         expected.length,
@@ -513,10 +537,13 @@ public void evaluate(Boolean[] expected, Output input) {
 
   /** {@inheritDoc} */
   @Override
-  public void evaluate(Output expected, Output input) {
+  public void evaluate(
+      Output expected,
+      Output input,
+      Map, Tensor> feedDict) {
     assert input.shape().equals(expected.shape())
         : String.format(
-            "expected shape (%s) != to input shape (%ds)",
+            "expected shape (%s) != to input shape (%s)",
             expected.shape().toString(), input.shape().toString());
     DataType dtype = input.asOutput().dataType();
     boolean isScalar = input.shape().equals(Shape.scalar());
@@ -683,7 +710,10 @@ public void evaluate(Output expected, Output input) {
 
   /** {@inheritDoc} */
   @Override
-  public void print(PrintWriter writer, Output input) {
+  public void print(
+      PrintWriter writer,
+      Output input,
+      Map, Tensor> feedDict) {
     DataType dtype = input.asOutput().dataType();
     if (dtype == TFloat32.DTYPE) {
       Output o = (Output) input;
diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java
index 1a22289f4bf..ff18b338ce2 100644
--- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java
+++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java
@@ -14,41 +14,32 @@ =======================================================================*/
 package org.tensorflow.keras.utils;
 
-import java.io.PrintWriter;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicLong;
-import java.util.function.Predicate;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.fail;
-import org.tensorflow.DataType;
-import org.tensorflow.EagerSession;
-import org.tensorflow.Graph;
-import org.tensorflow.Operand;
-import org.tensorflow.Output;
-import org.tensorflow.Session;
-import org.tensorflow.Tensor;
+import org.tensorflow.*;
+import org.tensorflow.Session.Runner;
 import org.tensorflow.ndarray.FloatNdArray;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
-import org.tensorflow.types.TBool;
-import org.tensorflow.types.TFloat32;
-import org.tensorflow.types.TFloat64;
-import org.tensorflow.types.TInt32;
-import org.tensorflow.types.TInt64;
-import org.tensorflow.types.TString;
+import org.tensorflow.types.*;
 import org.tensorflow.types.family.TNumber;
 import org.tensorflow.types.family.TType;
 
-/** Graph Mode Test Session */
+import java.io.PrintWriter;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Predicate;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+/** Graph Mode Test Session */
 public class GraphTestSession extends TestSession {
 
   private final Graph graph;
   private final Session session;
   private final Ops tf;
 
-  /** Create a Graph mode test session.
*/
+  /** Create a Graph Test Session */
   public GraphTestSession() {
     graph = new Graph();
     session = new Session(graph);
     tf = Ops.create(graph).withName("test");
@@ -61,15 +52,19 @@ public Ops getTF() {
     return tf;
   }
 
-  /** Get the Graph object that is represented by this Test Session */
+  /**
+   * Get the Graph instance for this test session
+   *
+   * @return the Graph instance for this test session
+   */
   public Graph getGraph() {
     return graph;
   }
-
-  /**
-   * Get the TensorFlow Session instance
-   * @return the TensorFlow Session instance
+  /**
+   * Get the Session instance for this test session
+   *
+   * @return the Session instance for this test session
    */
   public Session getSession() {
     return session;
   }
@@ -119,13 +114,50 @@ public void run(Op op) {
 
   /** {@inheritDoc} */
   @Override
-  public void evaluate(double expected, Operand input) {
+  public void run(Op op, Map, Tensor> feedDict) {
+    createRunner(op, feedDict).run();
+  }
+
+  /**
+   * Create a runner for the Operation
+   *
+   * @param op the operation
+   * @return the runner
+   */
+  public Runner createRunner(Op op) {
+    return createRunner(op, null);
+  }
+
+  /**
+   * Create a runner for the Operation
+   *
+   * @param op the operation
+   * @param feedDict the dictionary of values to use for the runner's feed operations. Required when
+   *     placeholders are used.
+   * @return the runner
+   */
+  public Runner createRunner(
+      Op op, Map, Tensor> feedDict) {
+    Runner runner = session.runner();
+    runner.addTarget(op.op());
+    if (feedDict != null && !feedDict.isEmpty()) {
+      feedDict.forEach((name, tensor) -> runner.feed(name, tensor));
+    }
+    return runner;
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public void evaluate(
+      double expected,
+      Operand input,
+      Map, Tensor> feedDict) {
     DataType dtype = input.asOutput().dataType();
     if (dtype == TFloat32.DTYPE) {
       AtomicInteger index = new AtomicInteger();
       if (debug) {
         try (Tensor result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
           result
               .data()
               .scalars()
@@ -137,7 +169,7 @@ public void evaluate(double expected, Operand input) {
       }
       index.set(0);
       try (Tensor result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) {
         result
             .data()
            .scalars()
@@ -150,7 +182,7 @@ public void evaluate(double expected, Operand input) {
      AtomicInteger index = new AtomicInteger();
      if (debug) {
        try (Tensor result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
          result
              .data()
              .scalars()
@@ -162,7 +194,7 @@ public void evaluate(double expected, Operand input) {
      }
      index.set(0);
      try (Tensor result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) {
        result
            .data()
            .scalars()
@@ -175,7 +207,7 @@ public void evaluate(double expected, Operand input) {
      AtomicInteger index = new AtomicInteger();
      if (debug) {
        try (Tensor result =
-            this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+            createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
          result
              .data()
              .scalars()
@@ -187,7 +219,7 @@ public void evaluate(double expected, Operand input) {
      }
      index.set(0);
      try (Tensor result =
-          this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) {
+          createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) {
        result
.data() .scalars() @@ -201,7 +233,7 @@ public void evaluate(double expected, Operand input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -213,7 +245,7 @@ public void evaluate(double expected, Operand input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -229,18 +261,24 @@ public void evaluate(double expected, Operand input) { /** {@inheritDoc} */ @Override - public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + public void evaluate( + Number[] expected, + Output input, + Map, Tensor> feedDict) { + long size = input.shape().size() == 0 ? 1 : input.shape().size(); + if (size != Shape.UNKNOWN_SIZE) { + assertEquals( + expected.length, + size, + () -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); + } DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -252,7 +290,7 @@ public void evaluate(Number[] expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -266,7 +304,7 @@ public void evaluate(Number[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -278,7 +316,7 @@ public void evaluate(Number[] expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -292,7 +330,7 @@ public void evaluate(Number[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -304,7 +342,7 @@ public void evaluate(Number[] expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -318,7 +356,7 @@ public void evaluate(Number[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - 
this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -330,7 +368,7 @@ public void evaluate(Number[] expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -346,13 +384,16 @@ public void evaluate(Number[] expected, Output input) { /** {@inheritDoc} */ @Override - public void evaluate(FloatNdArray expected, Output input) { + public void evaluate( + FloatNdArray expected, + Output input, + Map, Tensor> feedDict) { DataType dtype = input.asOutput().dataType(); if (dtype == TFloat32.DTYPE) { AtomicLong index = new AtomicLong(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -364,7 +405,7 @@ public void evaluate(FloatNdArray expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { result .data() .scalars() @@ -377,7 +418,7 @@ public void evaluate(FloatNdArray expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -389,7 +430,7 @@ public void evaluate(FloatNdArray expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { result .data() .scalars() @@ -403,7 +444,7 @@ public void evaluate(FloatNdArray expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -415,7 +456,7 @@ public void evaluate(FloatNdArray expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { result .data() .scalars() @@ -429,7 +470,7 @@ public void evaluate(FloatNdArray expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -441,7 +482,7 @@ public void evaluate(FloatNdArray expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { result .data() .scalars() @@ -457,7 +498,10 @@ public void evaluate(FloatNdArray expected, Output input) { /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, 
Output input) { + public void evaluate( + String[] expected, + Output input, + Map, Tensor> feedDict) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -466,7 +510,7 @@ public void evaluate(String[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { result .data() .scalars() @@ -478,7 +522,7 @@ public void evaluate(String[] expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { result .data() .scalars() @@ -491,7 +535,10 @@ public void evaluate(String[] expected, Output input) { /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output input) { + public void evaluate( + Boolean[] expected, + Output input, + Map, Tensor> feedDict) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); assertEquals( expected.length, @@ -500,7 +547,7 @@ public void evaluate(Boolean[] expected, Output input) { AtomicInteger index = new AtomicInteger(); if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { result .data() .scalars() @@ -512,7 +559,7 @@ public void evaluate(Boolean[] expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { result .data() .scalars() @@ -525,10 +572,13 @@ public void evaluate(Boolean[] expected, Output input) { /** {@inheritDoc} */ @Override - public void evaluate(Output expected, Output input) { + public void evaluate( + Output expected, + Output input, + Map, Tensor> feedDict) { assert input.shape().equals(expected.shape()) : String.format( - "expected shape (%s) != to input shape (%ds)", + "expected shape (%s) != to input shape (%s)", expected.shape().toString(), input.shape().toString()); AtomicInteger index = new AtomicInteger(); DataType dtype = input.asOutput().dataType(); @@ -537,9 +587,9 @@ public void evaluate(Output expected, Output input) { final Output finalExpected = (Output) expected; if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { System.out.printf( "0). 
%f <==> %f\n", expectedResult.data().getFloat(), result.data().getFloat()); @@ -560,9 +610,9 @@ public void evaluate(Output expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getFloat(), result.data().getFloat(), epsilon); } else { @@ -579,9 +629,9 @@ public void evaluate(Output expected, Output input) { final Output finalExpected = (Output) expected; if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { System.out.printf( "0). %f <==> %f\n", expectedResult.data().getDouble(), result.data().getDouble()); @@ -602,9 +652,9 @@ public void evaluate(Output expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getDouble(), result.data().getDouble(), epsilon); } else { @@ -621,9 +671,9 @@ public void evaluate(Output expected, Output input) { final Output finalExpected = (Output) expected; if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { System.out.printf( "0). 
%d <==> %d\n", expectedResult.data().getInt(), result.data().getInt()); @@ -642,9 +692,9 @@ public void evaluate(Output expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getInt(), result.data().getInt(), epsilon); } else { @@ -661,9 +711,9 @@ public void evaluate(Output expected, Output input) { final Output finalExpected = (Output) expected; if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { System.out.printf( "0). %d <==> %d\n", expectedResult.data().getLong(), result.data().getLong()); @@ -682,9 +732,9 @@ public void evaluate(Output expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getLong(), result.data().getLong(), epsilon); } else { @@ -701,9 +751,9 @@ public void evaluate(Output expected, Output input) { final Output finalExpected = (Output) expected; if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { System.out.printf( "0). 
%b <==> %b\n", expectedResult.data().getBoolean(), result.data().getBoolean()); @@ -724,9 +774,9 @@ public void evaluate(Output expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getBoolean(), result.data().getBoolean()); } else { @@ -743,9 +793,9 @@ public void evaluate(Output expected, Output input) { final Output finalExpected = (Output) expected; if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { System.out.printf( "0). %s <==> %s\n", expectedResult.data().getObject(), result.data().getObject()); @@ -766,9 +816,9 @@ public void evaluate(Output expected, Output input) { } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { assertEquals(expectedResult.data().getObject(), result.data().getObject()); } else { @@ -787,15 +837,17 @@ public void evaluate(Output expected, Output input) { } /** {@inheritDoc} */ - @Override - public void evaluate(Output input, Predicate predicate) { + public void evaluate( + Output input, + Predicate predicate, + Map, Tensor> feedDict) { AtomicInteger index = new AtomicInteger(); DataType dtype = input.asOutput().dataType(); boolean isScalar = input.shape().equals(Shape.scalar()); if (dtype == TFloat32.DTYPE) { if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %f\n", @@ -815,7 +867,7 @@ public void evaluate(Output input, Predicate predic } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getFloat())); } else { @@ -831,7 +883,7 @@ public void evaluate(Output input, Predicate predic } else if (dtype == TFloat64.DTYPE) { if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { System.out.printf( "0). 
%b <==> %f\n", @@ -851,9 +903,9 @@ public void evaluate(Output input, Predicate predic } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getDouble())); } else { @@ -869,7 +921,7 @@ public void evaluate(Output input, Predicate predic } else if (dtype == TInt32.DTYPE) { if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", predicate.test(result.data().getInt()), result.data().getInt()); @@ -888,9 +940,9 @@ public void evaluate(Output input, Predicate predic } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getInt())); } else { @@ -906,7 +958,7 @@ public void evaluate(Output input, Predicate predic } else if (dtype == TInt64.DTYPE) { if (debug) { try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { System.out.printf( "0). %b <==> %d\n", @@ -926,9 +978,9 @@ public void evaluate(Output input, Predicate predic } index.set(0); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE); + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE); Tensor expectedResult = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { assertTrue(predicate.test(result.data().getLong())); } else { @@ -948,14 +1000,17 @@ public void evaluate(Output input, Predicate predic /** {@inheritDoc} */ @Override - public void print(PrintWriter writer, Output input) { + public void print( + PrintWriter writer, + Output input, + Map, Tensor> feedDict) { boolean isScalar = input.asOutput().shape().size() == 1; DataType dtype = input.dataType(); if (dtype == TFloat32.DTYPE) { AtomicInteger index = new AtomicInteger(); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat32.DTYPE)) { if (isScalar) { writer.printf("%d). %f\n", index.getAndIncrement(), result.data().getFloat()); } else { @@ -972,7 +1027,7 @@ public void print(PrintWriter writer, Output input) { AtomicInteger index = new AtomicInteger(); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TFloat64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TFloat64.DTYPE)) { if (isScalar) { writer.printf( "%d). 
%f\n", index.getAndIncrement(), ((Output) input).data().getDouble()); @@ -990,7 +1045,7 @@ public void print(PrintWriter writer, Output input) { AtomicInteger index = new AtomicInteger(); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt32.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt32.DTYPE)) { if (isScalar) { writer.printf( "%d). %f\n", index.getAndIncrement(), ((Output) input).data().getInt()); @@ -1008,7 +1063,7 @@ public void print(PrintWriter writer, Output input) { AtomicInteger index = new AtomicInteger(); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TInt64.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TInt64.DTYPE)) { if (isScalar) { writer.printf( "%d). %f\n", index.getAndIncrement(), ((Output) input).data().getLong()); @@ -1026,7 +1081,7 @@ public void print(PrintWriter writer, Output input) { AtomicInteger index = new AtomicInteger(); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TBool.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TBool.DTYPE)) { if (isScalar) { writer.printf( "%d). %b\n", index.getAndIncrement(), ((Output) input).data().getBoolean()); @@ -1044,7 +1099,7 @@ public void print(PrintWriter writer, Output input) { AtomicInteger index = new AtomicInteger(); try (Tensor result = - this.getGraphSession().runner().fetch(input).run().get(0).expect(TString.DTYPE)) { + createRunner(input, feedDict).fetch(input).run().get(0).expect(TString.DTYPE)) { if (isScalar) { writer.printf( "%d). %s\n", index.getAndIncrement(), ((Output) input).data().getObject()); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java index 1e5393aa2af..34348ccc1f4 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java @@ -14,16 +14,7 @@ =======================================================================*/ package org.tensorflow.keras.utils; -import java.io.OutputStream; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.io.Writer; -import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.assertTrue; -import org.tensorflow.EagerSession; -import org.tensorflow.Operand; -import org.tensorflow.Output; -import org.tensorflow.Session; +import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -32,569 +23,1065 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -/** Base class for Test Session */ +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.io.Writer; +import java.util.Map; +import java.util.function.Predicate; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** @author Jim Clarke */ public abstract class TestSession implements AutoCloseable { protected float epsilon = 1e-5F; protected boolean debug; - /** The Test Session mode, either Eager or Graph */ + /** Enumerate between Eager and Graph Mode */ public enum Mode { EAGER, GRAPH; } - /** - * Create an Eager Test Session - * - * @return the Eager Test Session - */ public static TestSession createEagerSession() { return new EagerTestSession(); } - /** - * Create a Graph 
Test Session - * - * @return the Graph Test Session - */ public static TestSession createGraphSession() { return new GraphTestSession(); } - /** - * Create a Test Session - * - * @param mode - * @return - */ public static TestSession createTestSession(Mode mode) { return mode == Mode.EAGER ? createEagerSession() : createGraphSession(); } - /** Initialize the Test Session, default implementation is do nothing. */ public void initialize() { // empty } /** - * Run the Operation + * Perform session.run() + * + *
<p>
If in eager mode, this does nothing. * - * @param op the Operation to run + * @param op The Operation to run */ - public void run(Op op) { - // empty + public abstract void run(Op op); + + /** + * Perform session.run() + * + *
<p>
If in eager mode, this does nothing. + * + * @param op The Operation to run + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public abstract void run(Op op, Map, Tensor> feedDict); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param the data type of the input + */ + public void evaluate(Number expected, Operand input) { + evaluate(new Number[] {expected}, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input */ - public void evaluate(Number expected, Operand input) { - evaluate(new Number[] {expected}, input); + public void evaluate( + Number expected, + Operand input, + Map, Tensor> feedDict) { + evaluate(new Number[] {expected}, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value */ - public void evaluate(Number expected, Op input) { - evaluate(new Number[] {expected}, input); + public void evaluate(Number expected, Op input) { + evaluate(new Number[] {expected}, input, null); } /** - * Evaluate the input against the expected values + * Evaluate the expected results versus the actual results * - * @param expected the expected values - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type for the feedDict entries */ - public void evaluate(Number[] expected, Op input) { - Output output = input.op().output(0); - evaluate(expected, output); + public void evaluate( + Number expected, Op input, Map, Tensor> feedDict) { + evaluate(new Number[] {expected}, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type for the input + */ + public void evaluate(Number[] expected, Op input) { + Output output = input.op().output(0); + evaluate(expected, output, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
+ * @param the data type for the input + */ + public void evaluate( + Number[] expected, + Op input, + Map, Tensor> feedDict) { + Output output = input.op().output(0); + evaluate(expected, output, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param the data type of the input */ - public void evaluate(Number[] expected, Operand input) { + public void evaluate(Number[] expected, Operand input) { + Output output = input.asOutput(); + evaluate(expected, output, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input + */ + public void evaluate( + Number[] expected, + Operand input, + Map, Tensor> feedDict) { Output output = input.asOutput(); - evaluate(expected, output); + evaluate(expected, output, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input */ - public void evaluate(byte expected, Operand input) { - evaluate((double) expected, input); + public void evaluate(byte expected, Operand input) { + evaluate((double) expected, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input */ - public void evaluate(int expected, Operand input) { - evaluate((double) expected, input); + public void evaluate( + byte expected, + Operand input, + Map, Tensor> feedDict) { + evaluate((double) expected, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input */ - public void evaluate(long expected, Operand input) { - evaluate((double) expected, input); + public void evaluate(int expected, Operand input) { + evaluate((double) expected, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
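For illustration, a minimal sketch of how a graph-mode test might drive these new feedDict-aware evaluate overloads; the placeholder, constant and variable names below are hypothetical and are not part of the patch:

  // Assumed imports: org.tensorflow.Operand, org.tensorflow.Tensor, org.tensorflow.op.Ops,
  // org.tensorflow.op.core.Placeholder, org.tensorflow.ndarray.Shape,
  // org.tensorflow.types.TFloat32, java.util.Collections
  try (TestSession session = TestSession.createGraphSession()) {
    Ops tf = session.getTF();
    Placeholder<TFloat32> x =
        tf.placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar()));
    Operand<TFloat32> doubled = tf.math.mul(x, tf.constant(2.0f));
    try (Tensor<TFloat32> value = TFloat32.scalarOf(3.0f)) {
      // The map is handed to the runner's feed() before fetching, so graphs that
      // contain placeholders can be evaluated like any other operand.
      session.evaluate(6.0f, doubled, Collections.singletonMap(x, value));
    }
  }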
+ * @param the data type of the input */ - public void evaluate(float expected, Operand input) { - evaluate((double) expected, input); + public void evaluate( + int expected, + Operand input, + Map, Tensor> feedDict) { + evaluate((double) expected, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input */ - public abstract void evaluate(double expected, Operand input); + public void evaluate(long expected, Operand input) { + evaluate((double) expected, input, null); + } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input */ - public void evaluate(byte[] expected, Operand input) { + public void evaluate( + long expected, + Operand input, + Map, Tensor> feedDict) { + evaluate((double) expected, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param the data type of the input + */ + public void evaluate(float expected, Operand input) { + evaluate((double) expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input + */ + public void evaluate( + float expected, + Operand input, + Map, Tensor> feedDict) { + evaluate((double) expected, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param the data type of the input + */ + public void evaluate(double expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input + */ + public abstract void evaluate( + double expected, + Operand input, + Map, Tensor> feedDict); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param the data type of the input + */ + public void evaluate(byte[] expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
+ * @param the data type of the input + */ + public void evaluate( + byte[] expected, + Operand input, + Map, Tensor> feedDict) { Byte[] iArray = new Byte[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input */ - public void evaluate(int[] expected, Operand input) { + public void evaluate(int[] expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input + */ + public void evaluate( + int[] expected, + Operand input, + Map, Tensor> feedDict) { Integer[] iArray = new Integer[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input */ - public void evaluate(long[] expected, Operand input) { + public void evaluate(long[] expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input + */ + public void evaluate( + long[] expected, + Operand input, + Map, Tensor> feedDict) { Long[] iArray = new Long[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected values + * Evaluate the expected results versus the actual results * - * @param expected the expected values - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param expected the expected value + * @param input the actual value + * @param the data type of the input */ - public void evaluate(float[] expected, Operand input) { + public void evaluate(float[] expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
+ * @param the data type of the input + */ + public void evaluate( + float[] expected, + Operand input, + Map, Tensor> feedDict) { Float[] iArray = new Float[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input + */ + public void evaluate(double[] expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input */ - public void evaluate(double[] expected, Operand input) { + public void evaluate( + double[] expected, + Operand input, + Map, Tensor> feedDict) { Double[] iArray = new Double[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + for (int i = 0; i < expected.length; i++) { + iArray[i] = expected[i]; + } + evaluate(iArray, input, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input */ - public abstract void evaluate(Number[] expected, Output input); + public void evaluate(Number[] expected, Output input) { + evaluate(expected, input, null); + } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input + */ + public abstract void evaluate( + Number[] expected, + Output input, + Map, Tensor> feedDict); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(String expected, Operand input) { - evaluate(new String[] {expected}, input); + evaluate(expected, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
+ */ + public void evaluate( + String expected, + Operand input, + Map, Tensor> feedDict) { + evaluate(new String[] {expected}, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(String expected, Op input) { - evaluate(new String[] {expected}, input); + evaluate(new String[] {expected}, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + String expected, Op input, Map, Tensor> feedDict) { + evaluate(new String[] {expected}, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(String[] expected, Op input) { + evaluate(expected, input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + String[] expected, + Op input, + Map, Tensor> feedDict) { Output output = input.op().output(0); - evaluate(expected, output); + evaluate(expected, output, feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value */ public void evaluate(String[] expected, Operand input) { Output output = input.asOutput(); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
*/ - public abstract void evaluate(String[] expected, Output input); + public abstract void evaluate( + String[] expected, + Output input, + Map, Tensor> feedDict); /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value */ public void evaluate(Boolean expected, Operand input) { - evaluate(new Boolean[] {expected}, input); + evaluate(new Boolean[] {expected}, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + Boolean expected, + Operand input, + Map, Tensor> feedDict) { + evaluate(new Boolean[] {expected}, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(Boolean expected, Op input) { - evaluate(new Boolean[] {expected}, input); + evaluate(new Boolean[] {expected}, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void evaluate( + Boolean expected, Op input, Map, Tensor> feedDict) { + evaluate(new Boolean[] {expected}, input, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(Boolean[] expected, Op input) { Output output = input.op().output(0); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
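The GraphTestSession hunks earlier in this patch route every fetch through createRunner(input, feedDict) instead of getGraphSession().runner(). The helper's body is not shown in this excerpt; a plausible reconstruction (assuming the runner's feed(Output, Tensor) overload, a nullable map, and the generic signature used below) would be roughly:

  // Hypothetical sketch of the helper used by the hunks above, not the patch's actual code.
  protected Session.Runner createRunner(
      Output<? extends TType> input, Map<Operand<? extends TType>, Tensor<?>> feedDict) {
    Session.Runner runner = getGraphSession().runner();
    if (feedDict != null) {
      // Feed every placeholder/tensor pair so that fetching "input" can resolve them.
      feedDict.forEach((placeholder, tensor) -> runner.feed(placeholder.asOutput(), tensor));
    }
    return runner; // call sites then chain .fetch(input).run()
  }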
+ */ + public void evaluate( + Boolean[] expected, + Op input, + Map, Tensor> feedDict) { + Output output = input.op().output(0); + evaluate(expected, output, feedDict); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value */ public void evaluate(Boolean[] expected, Operand input) { Output output = input.asOutput(); - evaluate(expected, output); + evaluate(expected, output, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. */ - public abstract void evaluate(Boolean[] expected, Output input); + public void evaluate( + Boolean[] expected, + Operand input, + Map, Tensor> feedDict) { + Output output = input.asOutput(); + evaluate(expected, output, feedDict); + } - public void evaluate(Operand expected, Op input) { - Output output = input.op().output(0); - evaluate(expected, output); + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + */ + public void evaluate(Boolean[] expected, Output input) { + evaluate(expected, input, null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public abstract void evaluate( + Boolean[] expected, + Output input, + Map, Tensor> feedDict); + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + */ + public void evaluate(Operand expected, Output input) { + evaluate(expected.asOutput(), input, null); + } + + /** + * Evaluate the expected results versus the actual results + * + * @param expected the expected value + * @param input the actual value + * @param the data type for the feedDict entries */ public void evaluate(Operand expected, Operand input) { - evaluate(expected.asOutput(), input.asOutput()); + evaluate(expected.asOutput(), input.asOutput(), null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
+ * @param the data type for the feedDict entries */ - public abstract void evaluate(Output expected, Output input); + public abstract void evaluate( + Output expected, + Output input, + Map, Tensor> feedDict); /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input */ - public void evaluate(FloatNdArray expected, Operand input) { - evaluate(expected, input.asOutput()); + public void evaluate(FloatNdArray expected, Operand input) { + evaluate(expected, input.asOutput(), null); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input */ - public abstract void evaluate(FloatNdArray expected, Output input); - - public void evaluate(Operand input, Predicate predicate) { - evaluate(input.asOutput(), predicate); + public void evaluate( + FloatNdArray expected, + Operand input, + Map, Tensor> feedDict) { + evaluate(expected, input.asOutput(), feedDict); } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param the data type of the input */ - public abstract void evaluate(Output input, Predicate predicate); + public void evaluate(FloatNdArray expected, Output input) { + evaluate(expected, input, null); + } /** - * Evaluate the input against the expected value + * Evaluate the expected results versus the actual results * * @param expected the expected value - * @param input the operand to evaluate - * @param the data type of the input - * @throws org.opentest4j.AssertionFailedError + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type of the input + */ + public abstract void evaluate( + FloatNdArray expected, + Output input, + Map, Tensor> feedDict); + + /** + * Evaluate the actual results using a predicate + * + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail + * @param the data type of the input + */ + public void evaluate(Operand input, Predicate predicate) { + evaluate(input.asOutput(), predicate, null); + } + + /** + * Evaluate the actual results using a predicate + * + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
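A small, self-contained illustration of the predicate form; the graph built here is made up, and only the evaluate(Operand, Predicate) overload comes from the patch:

  // Assert a property of every element instead of comparing against fixed values.
  try (TestSession session = TestSession.createGraphSession()) {
    Ops tf = session.getTF();
    Operand<TFloat32> squares = tf.math.square(tf.constant(new float[] {-2f, 0.5f, 3f}));
    session.evaluate(squares, v -> v.floatValue() >= 0f); // every square is non-negative
  }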
+ * @param the data type of the input + */ + public abstract void evaluate( + Output input, + Predicate predicate, + Map, Tensor> feedDict); + + /** + * Evaluate the actual results using a predicate + * + * @param input the actual value + * @param predicate a predicate that accepts a Number as an argument, if the result of the + * predicate is false, then the test will fail */ - public void evaluate(FloatNdArray input, Predicate predicate) { + public void evaluate(FloatNdArray input, Predicate predicate) { input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat()))); } /** - * Print the input + * Print the results to output stream * * @param out the output stream - * @param input the operand to print - * @param the data type of the input + * @param input the actual value + * @param the data type for the input */ public void print(OutputStream out, Operand input) { - print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput()); + print(out, input, null); } /** - * Print the input + * Print the results to output stream * * @param out the output stream - * @param input the op to print - * @param the data type of the input + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type for the feedDict entries */ - public void print(OutputStream out, Op input) { - print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0)); + public void print( + OutputStream out, + Operand input, + Map, Tensor> feedDict) { + print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput(), feedDict); } /** - * Print the input + * Print the results to output stream * * @param out the output stream - * @param input the op to print - * @param the data type of the input + * @param input the actual value + */ + public void print(OutputStream out, Op input) { + print(out, input, null); + } + + /** + * Print the results to output stream + * + * @param out the output stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void print( + OutputStream out, Op input, Map, Tensor> feedDict) { + print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0), feedDict); + } + + /** + * Print the results to output stream + * + * @param out the output stream + * @param input the actual value + * @param the data type for the input */ public void print(OutputStream out, Output input) { - print(new PrintWriter(new OutputStreamWriter(out)), input); + print(out, input, null); } /** - * Print the input + * Print the results to output stream * - * @param witer the output writer - * @param input the operand to print - * @param the data type of the input + * @param out the output stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
+ * @param the data type for the input + */ + public void print( + OutputStream out, + Output input, + Map, Tensor> feedDict) { + print(new PrintWriter(new OutputStreamWriter(out)), input, feedDict); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param the data type for the input */ public void print(Writer writer, Operand input) { - print(new PrintWriter(writer), input.asOutput()); + print(writer, input, null); } /** - * Print the input + * Print the results to the character stream * - * @param witer the output writer - * @param input the op to print - * @param the data type of the input + * @param writer the character stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type for the input */ - public void print(Writer writer, Op input) { - print(new PrintWriter(writer), input.op().output(0)); + public void print( + Writer writer, + Operand input, + Map, Tensor> feedDict) { + print(new PrintWriter(writer), input.asOutput(), feedDict); } /** - * Print the input + * Print the results to the character stream * - * @param witer the output writer - * @param input the op to print - * @param the data type of the input + * @param writer the character stream + * @param input the actual value + */ + public void print(Writer writer, Op input) { + print(writer, input, null); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void print( + Writer writer, Op input, Map, Tensor> feedDict) { + print(new PrintWriter(writer), input.op().output(0), feedDict); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param the data type for the input */ public void print(Writer writer, Output input) { - print(new PrintWriter(writer), input); + print(writer, input, null); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type for the input + */ + public void print( + Writer writer, + Output input, + Map, Tensor> feedDict) { + print(new PrintWriter(writer), input, feedDict); + } + + /** + * Print the results to the character stream + * + * @param writer the character stream + * @param input the actual value + * @param the data type for the input + */ + public void print(PrintWriter writer, Output input) { + print(writer, input, null); } /** - * Print the input + * Print the results to the character stream * - * @param witer the output writer - * @param input the op to print + * @param writer the character stream + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. 
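One motivation for threading a feed dict through these run/evaluate/print overloads is the optimizer half of this patch series: in the framework optimizer diffs that follow, each optimizer gains a scalar learning-rate placeholder, a backing TFloat32 tensor, setLearningRate(float), getFeedDict() and close(). A rough, illustrative sketch of the round trip; no real loss is wired up, the runner's feed(Output, Tensor) overload is assumed, and the enclosing method is assumed to declare throws Exception:

  // Assumed imports: org.tensorflow.Graph, org.tensorflow.Session,
  // org.tensorflow.framework.optimizers.Adam
  try (Graph graph = new Graph()) {
    Adam optimizer = new Adam(graph, 0.001f); // the constructor builds the learning-rate placeholder
    optimizer.setLearningRate(1e-4f);         // swaps the backing scalar tensor and the feed dict
    try (Session session = new Session(graph)) {
      Session.Runner runner = session.runner();
      // Feed the placeholder -> tensor entry; with a real train op you would also
      // add it as a target (e.g. runner.addTarget(trainOp.op())) before run().
      optimizer.getFeedDict().forEach((p, t) -> runner.feed(p.asOutput(), t));
    }
    optimizer.close();                        // releases the current learning-rate tensor
  }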
+ * @param the data type for the input */ - public abstract void print(PrintWriter writer, Output input); + public abstract void print( + PrintWriter writer, + Output input, + Map, Tensor> feedDict); /** - * Get the TensorFlow Ops + * Get the TensorFlow Ops for this test session * - * @return the TensorFlow Ops + * @return the TensorFlow Ops for this test session */ public abstract Ops getTF(); /** - * Determine if this Test Session represents an Eager Session + * Determine whether this session is in Eager mode * - * @return true, if this Test Session represents an Eager Session + * @return true if the this session is in Eager mode */ public abstract boolean isEager(); /** - * Determine if this Test Session represents a Graph Session + * Determine whether this session is in Graph mode * - * @return true, if this Test Session represents a Graph Session + * @return true if the this session is in Graph mode */ public boolean isGraph() { return !isEager(); } /** - * Get the epsilon value for evaluating float values + * Get the current EPSILON value for floating point number comparison. * - * @return the epsilon value for evaluating float values + * @return the current EPSILON value for floating point number comparison. */ public float getEpsilon() { return this.epsilon; } /** - * Set the epsilon value for evaluating float values + * Set the current EPSILON value for floating point number comparison. * - * @param epsilon the epsilon value for evaluating float values + * @param epsilon the new EPSILON value for floating point number comparison. */ public void setEpsilon(float epsilon) { this.epsilon = epsilon; } /** - * Get the TensorFlow session object associated with this Test Session + * Get the TensorFlow Session object * - * @return a TensorFlow session if this is a Graph session, otherwise null + * @return the TensorFlow Session object, returns null if this is not a Graph Test Session */ public abstract Session getGraphSession(); /** - * Get the TensorFlow eager session object associated with this Test Session + * Get the TensorFlow EagerSession object * - * @return a TensorFlow session if this is an eager session, otherwise null + * @return the TensorFlow Session object, returns null if this is not a Graph Test Session */ public abstract EagerSession getEagerSession(); @@ -602,15 +1089,21 @@ public void setEpsilon(float epsilon) { @Override public abstract void close(); - /** @return the debug setting */ + /** + * Get the debug setting + * + * @return the debug setting + */ public boolean isDebug() { return debug; } /** - * Set the debug flag + * Sets the debug setting. + * + *
<p>
If true, then evaluate methods will also print the Tensor values to System.out. * - * @param debug the setting for debugging + * @param debug the debug to set */ public void setDebug(boolean debug) { this.debug = debug; From fa936ccb8e0675452c95ddf60bb50721f0b400a9 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 30 Aug 2020 18:43:15 -0400 Subject: [PATCH 02/14] Add ability to change learning rate between steps by adding a Placeholder into each Optimizer. Also, added to each Optimizer a corresponding Tensor that holds the value of the learning rate, and added a feed dictionary that maps the placeholder to the Tensor, so that it can be fed into the runner when running or evaluating. When setLearning rate is called the learning rate tensor and the feed dictionary are updated. --- .../framework/optimizers/AdaDelta.java | 54 +++++++++++++++++- .../framework/optimizers/AdaGrad.java | 53 +++++++++++++++++- .../framework/optimizers/AdaGradDA.java | 51 ++++++++++++++++- .../tensorflow/framework/optimizers/Adam.java | 52 +++++++++++++++-- .../framework/optimizers/GradientDescent.java | 54 +++++++++++++++++- .../framework/optimizers/Momentum.java | 53 +++++++++++++++++- .../framework/optimizers/Optimizer.java | 31 +++++++--- .../framework/optimizers/RMSProp.java | 56 ++++++++++++++++++- 8 files changed, 380 insertions(+), 24 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index b5dc2434d60..aa77a6146f9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -15,12 +15,19 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -33,7 +40,10 @@ public class AdaDelta extends Optimizer { public static final String ACCUMULATOR = "accum"; public static final String ACCUMULATOR_UPDATE = "accum_update"; - private final float learningRate; + private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; private final float rho; @@ -46,6 +56,10 @@ public AdaDelta(Graph graph, float learningRate) { public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.rho = rho; this.epsilon = epsilon; } @@ -57,6 +71,10 @@ public AdaDelta(Graph graph, String name, float learningRate) { public AdaDelta(Graph graph, String name, float learningRate, float rho, float epsilon) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + 
tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.rho = rho; this.epsilon = epsilon; } @@ -82,7 +100,7 @@ protected Op applyDense(Output gradient, Output variable Variable accumSlot = getSlot(variable, ACCUMULATOR).get(); Variable accumUpdateSlot = getSlot(variable, ACCUMULATOR_UPDATE).get(); return tf.train.applyAdadelta(variable, accumSlot, accumUpdateSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(tf.constant(rho), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); @@ -101,4 +119,36 @@ public String toString() { public String getOptimizerName() { return "Adadelta"; } + + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 4dfabb21357..7df1ddfa991 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -15,12 +15,19 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -33,7 +40,10 @@ public class AdaGrad extends Optimizer { public static final String ACCUMULATOR = "accumulator"; - private final float learningRate; + private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; private final float initialAccumulatorValue; @@ -44,6 +54,10 @@ public AdaGrad(Graph graph, float learningRate) { public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.initialAccumulatorValue = initialAccumulatorValue; } @@ -54,6 +68,10 @@ public AdaGrad(Graph graph, String name, float learningRate) { public AdaGrad(Graph 
graph, String name, float learningRate, float initialAccumulatorValue) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.initialAccumulatorValue = initialAccumulatorValue; } @@ -74,7 +92,7 @@ private void createAdaGradSlot(Output v) { protected Op applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, ACCUMULATOR).get(); return tf.train - .applyAdagrad(variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + .applyAdagrad(variable, slot, tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), gradient); } @@ -90,4 +108,35 @@ public String toString() { public String getOptimizerName() { return "Adagrad"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index 0544309dc7f..4d590906b2b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -15,15 +15,20 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -36,7 +41,10 @@ public class AdaGradDA extends Optimizer { public static final String ACCUMULATOR = "gradient_accumulator"; public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; - private final float learningRate; + private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; private final float initialAccumulatorValue; private final float l1Strength; private final float l2Strength; @@ -50,6 +58,10 @@ public AdaGradDA(Graph graph, float learningRate, float initialAccumulatorValue, float l2Strength) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + 
tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.initialAccumulatorValue = initialAccumulatorValue; this.l1Strength = l1Strength; this.l2Strength = l2Strength; @@ -63,6 +75,10 @@ public AdaGradDA(Graph graph, String name, float learningRate, float initialAccu float l2Strength) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.initialAccumulatorValue = initialAccumulatorValue; this.l1Strength = l1Strength; this.l2Strength = l2Strength; @@ -97,7 +113,7 @@ protected Op applyDense(Output gradient, Output variable Variable gradSlot = getSlot(variable, ACCUMULATOR).get(); Variable gradSquaredSlot = getSlot(variable, SQUARED_ACCUMULATOR).get(); return tf.train.applyAdagradDa(variable, gradSlot, gradSquaredSlot, gradient, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(tf.constant(l1Strength), gradient.dataType()), tf.dtypes.cast(tf.constant(l2Strength), gradient.dataType()), globalStep); @@ -133,4 +149,35 @@ public String toString() { public String getOptimizerName() { return "adagrad-da"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index 11ab4be6b64..ac07f4e2fc9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -15,17 +15,21 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TFloat32; @@ -42,7 +46,10 @@ public class Adam extends Optimizer { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; - 
private final float learningRate; + private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; private final float betaOne; @@ -50,7 +57,6 @@ public class Adam extends Optimizer { private final float epsilon; - private Constant learningRateConst; private Constant epsilonConst; private Constant betaOneConst; private Constant betaTwoConst; @@ -64,6 +70,10 @@ public Adam(Graph graph, float learningRate) { public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -76,6 +86,10 @@ public Adam(Graph graph, String name, float learningRate) { public Adam(Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.betaOne = betaOne; this.betaTwo = betaTwo; this.epsilon = epsilon; @@ -121,7 +135,6 @@ protected void createSlots(List> variables) { protected Optional prepare(String scopeName) { betaOneConst = tf.constant(betaOne); betaTwoConst = tf.constant(betaTwo); - learningRateConst = tf.constant(learningRate); epsilonConst = tf.constant(epsilon); return Optional.empty(); } @@ -142,7 +155,7 @@ protected Op applyDense(Output gradient, Output variable return tf.train.applyAdam(variable, firstMomentSlot, secondMomentSlot, tf.dtypes.cast(betaOnePower, gradient.dataType()), tf.dtypes.cast(betaTwoPower, gradient.dataType()), - tf.dtypes.cast(learningRateConst, gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(betaOneConst, gradient.dataType()), tf.dtypes.cast(betaTwoConst, gradient.dataType()), tf.dtypes.cast(epsilonConst, gradient.dataType()), @@ -179,4 +192,35 @@ public String toString() { public String getOptimizerName() { return "Adam"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index 7ed90c846f1..3ba437d241e 100644 --- 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -16,31 +16,50 @@ package org.tensorflow.framework.optimizers; import org.tensorflow.Graph; +import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.Collections; +import java.util.Map; + /** * Basic SGD. */ public class GradientDescent extends Optimizer { - private final float learningRate; + private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; public GradientDescent(Graph graph, float learningRate) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); } public GradientDescent(Graph graph, String name, float learningRate) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); } @Override protected Op applyDense(Output gradient, Output variable) { return tf.train.applyGradientDescent(variable, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), gradient); + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), gradient); } @Override @@ -54,4 +73,35 @@ public String toString() { public String getOptimizerName() { return "GradientDescent"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index b8582b4e278..a058649373a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -15,13 +15,20 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import 
org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -34,7 +41,10 @@ public class Momentum extends Optimizer { public static final String MOMENTUM = "momentum"; - private final float learningRate; + private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; private final float momentum; @@ -43,6 +53,10 @@ public class Momentum extends Optimizer { public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.momentum = momentum; this.useNesterov = useNesterov; } @@ -50,6 +64,10 @@ public Momentum(Graph graph, float learningRate, float momentum, boolean useNest public Momentum(Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.momentum = momentum; this.useNesterov = useNesterov; } @@ -71,7 +89,7 @@ private void createMomentumSlot(Output v) { protected Op applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, MOMENTUM).get(); return tf.train - .applyMomentum(variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + .applyMomentum(variable, slot, tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), gradient, tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), ApplyMomentum.useNesterov(useNesterov)); @@ -90,4 +108,35 @@ public String toString() { public String getOptimizerName() { return "Momentum"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index ffff35a8ddd..def464a86ca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -21,10 
+21,8 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Operation; -import org.tensorflow.Output; + +import org.tensorflow.*; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; @@ -36,8 +34,8 @@ /** * Base class for gradient optimizers. */ -public abstract class Optimizer { - +public abstract class Optimizer implements AutoCloseable { + public static final String LEARNING_RATE = "learning_rate"; public static final String VARIABLE_V2 = "VariableV2"; /** * Global state variables @@ -247,7 +245,26 @@ protected Op finish(List updateOperations, String name) { public abstract String getOptimizerName(); /** - * Optional attributes for {@link org.tensorflow.training.optimizers.Optimizer} + * Set the learning rate + * @param learningRate the learning rate + */ + public abstract void setLearningRate(float learningRate); + + /** + * Get the learning rate + * @return the learning rate + */ + public abstract float getLearningRate(); + + /** + * Get the Feed Dictionary for the run methods to set the Placeholder values(s) + * + * @return the current Feed Dictionary for the run methods + */ + public abstract Map, Tensor> getFeedDict(); + + /** + * Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} */ public static class Options { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index cc64a23de3d..3d28c016de7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -15,12 +15,19 @@ */ package org.tensorflow.framework.optimizers; +import java.util.Collections; import java.util.List; +import java.util.Map; + import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -35,7 +42,10 @@ public class RMSProp extends Optimizer { public static final String MG = "mg"; // mean gradient? 
public static final String MOMENTUM = "momentum"; - private final float learningRate; + private float learningRate; + private Tensor learningRateTensor; + private final Placeholder learningRatePlaceholder; + private Map, Tensor> feedDict; private final float decay; private final float momentum; private final float epsilon; @@ -49,6 +59,11 @@ public RMSProp(Graph graph, float learningRate, float decay, float momentum, flo boolean centered) { super(graph); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; @@ -63,6 +78,10 @@ public RMSProp(Graph graph, String name, float learningRate, float decay, float boolean centered) { super(graph, name); this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; @@ -97,14 +116,14 @@ protected Op applyDense(Output gradient, Output variable if (centered) { Variable mgSlot = getSlot(variable, MG).get(); return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); } return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, - tf.dtypes.cast(tf.constant(learningRate), gradient.dataType()), + tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), @@ -126,4 +145,35 @@ public String toString() { public String getOptimizerName() { return "RMSProp"; } + + /** {@inheritDoc} */ + @Override + public float getLearningRate() { + return this.learningRate; + } + + /** {@inheritDoc} */ + @Override + public final void setLearningRate(float learningRate) { + this.learningRate = learningRate; + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + } + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + } + + /** {@inheritDoc} */ + public Map, Tensor> getFeedDict() { + return this.feedDict; + } + + /** {@inheritDoc} */ + @Override + public void close() throws Exception { + if (this.learningRateTensor != null) { + this.learningRateTensor.close(); + this.learningRateTensor = null; + } + } } From 0afdb9ccd17e55941d9f48f78b8bbd7e31ec926a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 1 Sep 2020 12:55:34 -0400 Subject: [PATCH 03/14] Add support for hanling feed dicts when evaluating or printing Operands. 
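Illustrative call pattern (a sketch only, not part of this patch: the session, instance, update, var0 and expected0 names are assumed from the optimizer tests added later in this series):

    // run the update op, feeding the optimizer's learning-rate placeholder,
    // then print and check a variable using the same feed dict
    session.run(update, instance.getFeedDict());
    session.print(var0, instance.getFeedDict());
    session.evaluate(expected0, var0);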
--- .../keras/utils/EagerTestSession.java | 2 +- .../keras/utils/GraphTestSession.java | 2 +- .../tensorflow/keras/utils/TestSession.java | 68 ++++++++++++++++++- 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java index 6d286c311ff..8c0f26b21e7 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/EagerTestSession.java @@ -31,7 +31,7 @@ import static org.junit.jupiter.api.Assertions.*; -/** @author Jim Clarke */ +/** An Eager Mode Test Session */ public class EagerTestSession extends TestSession { private final EagerSession session; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java index ff18b338ce2..98ff9d40c04 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/GraphTestSession.java @@ -32,7 +32,7 @@ import static org.junit.jupiter.api.Assertions.*; -/** @author Jim Clarke */ +/** A Graph Mode Test Session */ public class GraphTestSession extends TestSession { private final Graph graph; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java index 34348ccc1f4..cd4b891a039 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java @@ -32,7 +32,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; -/** @author Jim Clarke */ +/** Abstract class for Test Sessions */ public abstract class TestSession implements AutoCloseable { protected float epsilon = 1e-5F; @@ -851,6 +851,72 @@ public void evaluate(FloatNdArray input, Predicate predicate) { input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat()))); } + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param the data type for the input + */ + public void print(Operand input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type for the feedDict entries + */ + public void print( + Operand input, Map, Tensor> feedDict) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput(), feedDict); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + */ + public void print(Op input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + */ + public void print(Op input, Map, Tensor> feedDict) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedDict); + } + + /** + * Print the results to the "standard" output stream. 
+ * + * @param input the actual value + * @param the data type for the input + */ + public void print(Output input) { + print(System.out, input, null); + } + + /** + * Print the results to the "standard" output stream. + * + * @param input the actual value + * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * required for placeholders. + * @param the data type for the input + */ + public void print( + Output input, Map, Tensor> feedDict) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedDict); + } + /** * Print the results to output stream * From 9f10da9e877aeea343b7513bcab1038a03fb9869 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 1 Sep 2020 12:58:59 -0400 Subject: [PATCH 04/14] Add tests for changing learning rates --- .../keras/optimizers/AdaDeltaTest.java | 95 +++++++++++ .../keras/optimizers/AdaGradDATest.java | 67 ++++++++ .../keras/optimizers/AdaGradTest.java | 80 +++++++++ .../tensorflow/keras/optimizers/AdamTest.java | 151 +++++++++++++++++ .../keras/optimizers/AdamaxTest.java | 122 ++++++++++++- .../keras/optimizers/NadamTest.java | 160 +++++++++++++++++- .../keras/optimizers/RMSPropTest.java | 136 +++++++++++++++ .../tensorflow/keras/optimizers/SGDTest.java | 73 ++++++++ 8 files changed, 882 insertions(+), 2 deletions(-) diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java index e8a3bc14d9b..403803295e9 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaDeltaTest.java @@ -202,4 +202,99 @@ public void testBasic() { } } } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 4; // # number of ADADELTA steps to perform + float[] grads = {0.2F, 0.1F, 0.01F}; + + for (float grad : grads) { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + float lr = 1.0F; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] fgrads = {grad, grad}; + Shape shape = Shape.of(var0_init.length); + Variable var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant cgrads = tf.constant(fgrads); + + float accum = 0.0F; + float accum_update = 0.0F; + float rho = 0.95F; + float epsilon = 1e-8F; + float epsilon1 = 1e-5F; + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); + + /* get the Optimizer */ + AdaDelta instance = new AdaDelta(tf, lr, rho, epsilon); + + Op adadelta_update = instance.applyGradients(gradsAndVars, "AdaDeltaTest"); + + /* Create and validae the shapes of the slota */ + Variable[] slots = new Variable[2]; + Variable[] slotUpdates = new Variable[2]; + + slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + + slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + + slots[1] = instance.getSlot(var1.asOutput(), 
ACCUMULATOR).get(); + assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + + slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + float[] updates = new float[numSteps]; + float totUpdate = 0; + for (int step = 0; step < numSteps; step++) { + session.run(adadelta_update, instance.getFeedDict()); + accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); + updates[step] = + ((float) Math.sqrt(accum_update + epsilon) + * (float) (1 / Math.sqrt(accum + epsilon)) + * grad); + accum_update = (accum_update * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); + totUpdate += updates[step] * lr; + + for (int i = 0; i < 2; i++) { + session.evaluate(accum, slots[i]); + session.evaluate(accum_update, slotUpdates[i]); + } + + Float[] var0_initUpdate = {var0_init[0] - totUpdate, var0_init[1] - totUpdate}; + Float[] var1_initUpdate = {var1_init[0] - totUpdate, var1_init[1] - totUpdate}; + + session.evaluate(var0_initUpdate, var0); + session.evaluate(var1_initUpdate, var1); + + // Adjust learning rate + lr *= 0.9F; + instance.setLearningRate(lr); + } + } + } + } } diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java index 85f4220c4c7..98c8145515b 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradDATest.java @@ -121,4 +121,71 @@ public void testBasic() { session.evaluate(expected1, var1); } } + + @Test + public void testWithLearningRateDecay() { + float[] var0_init = {0.0F, 0.0F}; + float[] var1_init = {0.0F, 0.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.02F}; + float epsilon = 1e-8F; + float epsilon1 = 1e-5F; + int numSteps = 4; + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); + + /* initialize the local variables */ + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + float learningRate = 3.0F; + + AdaGrad instance = new AdaGrad(tf, learningRate); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdGradDATest"); + + /** initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + float[][] expected0 = { + {-0.904534F, 
-1.603567F}, + {-1.683957F, -2.8763597F}, + {-2.3579178F, -3.9125152F}, + {-2.942418F, -4.770327F} + }; + float[][] expected1 = { + {-0.094821F, -0.189358F}, + {-0.18011717F, -0.35944232F}, + {-0.2568455F, -0.51221514F}, + {-0.3258666F, -0.6494397F} + }; + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedDict()); + System.out.println("step: " + i); + session.evaluate(expected0[i], var0); + session.evaluate(expected1[i], var1); + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } } diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java index b6f1d7c88fc..0de84aac0c6 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdaGradTest.java @@ -41,6 +41,7 @@ /** Test cases for AdaGrad Optimizer */ public class AdaGradTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; int index; @@ -149,6 +150,85 @@ public void testBasic() { } } + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + float epsilon = 1e-8F; + float epsilon1 = 1e-5F; + float[] accum0 = {0.1f, 0.1f}; + float[] accum1 = {0.1f, 0.1f}; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + FloatNdArray accum0_np = NdArrays.vectorOf(accum0); + FloatNdArray accum1_np = NdArrays.vectorOf(accum1); + + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); + + float learningRate = 3.0F; + + AdaGrad instance = new AdaGrad(tf, learningRate); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op ada_update = instance.applyGradients(gradsAndVars, "AdGradTest"); + + Variable[] accumulatorSlots = new Variable[2]; + accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape()); + + accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + for (int step = 0; step < numSteps; step++) { + session.run(ada_update, instance.getFeedDict()); + 
+ accum0_np = caclulateAccum(accum0_np, grads0_np); + var0_np = calculate(var0_np, accum0_np, grads0_np, learningRate); + session.evaluate(var0_np, var0); + + accum1_np = caclulateAccum(accum1_np, grads1_np); + var1_np = calculate(var1_np, accum1_np, grads1_np, learningRate); + session.evaluate(var1_np, var1); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray caclulateAccum(FloatNdArray accum, FloatNdArray grads) { // accum + g_t * g_t FloatNdArray squareG = ND.square(grads); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java index 6a8f0f5078c..67bc9701935 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamTest.java @@ -42,6 +42,7 @@ /** Test cases for Adam Optimizer */ public class AdamTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; int index; @@ -224,6 +225,156 @@ public void testBasic() { } } + @Test + public void testWithLearningRateDecay() { + float m0 = 0.0F; + float v0 = 0.0F; + float m1 = 0.0F; + float v1 = 0.0F; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + float epsilon1 = 1e-3F; + + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + + session.setEpsilon(epsilon1); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + float learningRate = 0.001F; + float beta1 = 0.9F; + float beta2 = 0.999F; + float epsilon = 1e-8F; + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Adam instance = new Adam(tf, learningRate); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + Variable[] firstMomentSlots = new Variable[2]; + Variable[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + 
assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /** initialize the accumulators */ + session.run(tf.init()); + + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + FloatNdArray m0_np = NdArrays.ofFloats(shape1); + FloatNdArray v0_np = NdArrays.ofFloats(shape1); + FloatNdArray m1_np = NdArrays.ofFloats(shape1); + FloatNdArray v1_np = NdArrays.ofFloats(shape1); + + for (int step = 0; step < 3; step++) { + + // Test powers + final float[] powers = { + (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) + }; + + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(powers[0], f.getFloat(), epsilon1); + }); + } + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("beta2_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(powers[1], f.getFloat(), epsilon1); + }); + } + session.run(update, instance.getFeedDict()); + + float lr_t = + learningRate + * (float) Math.sqrt(1 - (float) Math.pow(beta2, (step + 1))) + / (1 - (float) Math.pow(beta1, (step + 1))); + + m0_np = calculateM(m0_np, grads0_np, beta1); + v0_np = calculateV(v0_np, grads0_np, beta2); + var0_np = calculateParam(var0_np, lr_t, m0_np, v0_np, 1e-7F); + + m1_np = calculateM(m1_np, grads1_np, beta1); + v1_np = calculateV(v1_np, grads1_np, beta2); + var1_np = calculateParam(var1_np, lr_t, m1_np, v1_np, 1e-7F); + + // evaluate var 0 and var1 + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + // first moment + session.evaluate(m0_np, firstMomentSlots[0]); + session.evaluate(m1_np, firstMomentSlots[1]); + + // second moment + session.evaluate(v0_np, secondMomentSlots[0]); + session.evaluate(v1_np, secondMomentSlots[1]); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray calculateM(FloatNdArray m, FloatNdArray g_t, float beta) { // m_t = beta1 * m + (1 - beta1) * g_t return ND.add(ND.mul(m, beta), ND.mul(g_t, (1 - beta))); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java index 3f6b232c179..24ec7cb15cc 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/AdamaxTest.java @@ -40,6 +40,7 @@ /** Test cases for Adamax Optimizer */ public class AdamaxTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; private static final int VAR = 0; @@ -197,16 +198,135 @@ public void testBasic() { v1 = resultNP[V]; // evaluate var0 and var1 + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + } + } + } + + @Test + public void testWithLearningRateDecay() { + + float epsilon = 1e-6f; + float epsilon1 = 1e-3F; + int numSteps = 3; + + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + float[] zeros = {0.0F, 0.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + FloatNdArray var0_np = 
NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + float learningRate = 0.001F; + + Adamax instance = new Adamax(tf, learningRate); + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + Variable[] firstMomentSlots = new Variable[2]; + Variable[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /** initialize the accumulators */ + session.run(tf.init()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + session.setEpsilon(epsilon1); + for (int step = 0; step < numSteps; step++) { + // Test powers + final float beta1_power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(beta1_power, f.getFloat(), epsilon1); + }); + } + session.run(update, instance.getFeedDict()); + + FloatNdArray[] resultNP = calculate(var0_np, grads0_np, step, m0, v0, learningRate); + var0_np = resultNP[VAR]; + m0 = resultNP[M]; + v0 = resultNP[V]; + + resultNP = calculate(var1_np, grads1_np, step, m1, v1, learningRate); + var1_np = resultNP[VAR]; + m1 = resultNP[M]; + v1 = resultNP[V]; + + // evaluate var0 and var1 session.evaluate(var0_np, var0); session.evaluate(var1_np, var1); + + learningRate *= 0.9F; + instance.setLearningRate(learningRate); } } } private FloatNdArray[] calculate( FloatNdArray var_np, FloatNdArray grads_np, int step, FloatNdArray m, FloatNdArray v) { - float alpha = 0.001F; + return calculate(var_np, grads_np, step, m, v, 0.001F); + } + + private FloatNdArray[] calculate( + FloatNdArray var_np, + FloatNdArray grads_np, + int step, + FloatNdArray m, + FloatNdArray v, + float alpha) { float beta1 = BETA_ONE_DEFAULT; float beta2 = BETA_TWO_DEFAULT; float espilon = 1e-8F; diff --git 
a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java index 6314b4b8b4c..32d90ea91ed 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java @@ -237,6 +237,154 @@ public void testBasic() { } } + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray mcache = NdArrays.vectorOf(ones); + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + float epsilon = 1e-6f; + float epsilon1 = 1e-3F; + + float learningRate = 0.001F; + + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); + + Nadam instance = new Nadam(tf, learningRate); + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + Variable[] firstMomentSlots = new Variable[2]; + Variable[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + session.setEpsilon(epsilon1); + + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(1F, f.getFloat(), epsilon1); + }); + } + momentum = 1F; + + for (int step = 0; step 
< numSteps; step++) { + + session.run(update, instance.getFeedDict()); + + float mut = + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + momentum = momentum * mut; + + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result + .data() + .scalars() + .forEach( + f -> { + assertEquals(momentum, f.getFloat(), epsilon1); + }); + } + mcache = ND.mul(mcache, momentum); + FloatNdArray[] resultsNP = + nadam_update_numpy(var0_np, grads0_np, step, m0, v0, mcache, learningRate); + var0_np = resultsNP[VAR]; + m0 = resultsNP[M]; + v0 = resultsNP[V]; + + resultsNP = nadam_update_numpy(var1_np, grads1_np, step, m1, v1, mcache, learningRate); + var1_np = resultsNP[VAR]; + m1 = resultsNP[M]; + v1 = resultsNP[V]; + + // evaluate m0 and m1 + session.evaluate(m0, firstMomentSlots[0]); + session.evaluate(m1, firstMomentSlots[1]); + + // evaluate v0 and v1 + session.evaluate(v0, secondMomentSlots[0]); + session.evaluate(v1, secondMomentSlots[1]); + + // evaluate var0 and var1 + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + learningRate *= 0.9; + instance.setLearningRate(learningRate); + } + } + } + private FloatNdArray update_m_cache(FloatNdArray mcache, int t) { float mu_t = 0.9F * (1.0F - 0.5F * (float) Math.pow(0.96, (0.004 * (t + 1)))); return ND.mul(mu_t, mcache); @@ -249,8 +397,18 @@ private FloatNdArray[] nadam_update_numpy( FloatNdArray m, FloatNdArray v, FloatNdArray m_cache) { + return nadam_update_numpy(var_np, grads_np, t, m, v, m_cache, 0.001F); + } + + private FloatNdArray[] nadam_update_numpy( + FloatNdArray var_np, + FloatNdArray grads_np, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray m_cache, + float alpha) { - float alpha = 0.001F; float beta1 = 0.9F; float beta2 = 0.999F; float epsilon = 1e-8F; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java index 2a43bdb3df2..7651872643b 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java +++ b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java @@ -41,6 +41,7 @@ /** Test cases for RMSProp Optimizer */ public class RMSPropTest { + private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; final int VAR_T = 0; @@ -224,6 +225,141 @@ public void testDense() { } } + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + + for (int run = 0; run < _test_param_values.length; run++) { + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + session.setEpsilon(1e-2f); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.2F}; + final float epsilon1 = 1e-2F; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant 
grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); + + // learning_rate, rho (decay), momentum, epsilon, centered + float learningRate = (float) (float) _test_param_values[run][0]; + float decay = (float) _test_param_values[run][1]; + float momentum = (float) _test_param_values[run][2]; + float epsilon = (float) _test_param_values[run][3]; + boolean centered = (boolean) _test_param_values[run][4]; + + RMSProp instance = new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + Variable mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; + Variable mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; + Variable mom0 = + momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; + Variable mom1 = + momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; + Variable rms0 = instance.getSlot(var0.asOutput(), RMS).get(); + Variable rms1 = instance.getSlot(var1.asOutput(), RMS).get(); + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0_np = NdArrays.vectorOf(zeros); + FloatNdArray mg1_np = NdArrays.vectorOf(zeros); + FloatNdArray rms0_np = NdArrays.vectorOf(ones); + FloatNdArray rms1_np = NdArrays.vectorOf(ones); + FloatNdArray mom0_np = NdArrays.vectorOf(zeros); + FloatNdArray mom1_np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedDict()); + FloatNdArray[] result0 = + calc( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0_np = result0[VAR_T]; + mg0_np = result0[MG_T]; + rms0_np = result0[RMS_T]; + mom0_np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1_np = result1[VAR_T]; + mg1_np = result1[MG_T]; + rms1_np = result1[RMS_T]; + mom1_np = result1[MOM_T]; + + if (centered) { + session.evaluate(mg0_np, mg0); + session.evaluate(mg0_np, mg0); + } + if (momentum > 0.F) { + session.evaluate(mom0_np, mom0); + session.evaluate(mom1_np, mom1); + } + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + session.evaluate(rms0_np, rms0); + session.evaluate(rms1_np, rms1); + + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + learningRate *= 0.9F; + instance.setLearningRate(learningRate); + } + } + } + } + FloatNdArray[] calc( FloatNdArray var_np, FloatNdArray grad_np, diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java index 3d24b85239a..1cf20f1b0d2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java +++ 
b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java @@ -218,4 +218,77 @@ public void testMomentum() { session.evaluate(expectedVar1_2, var1); } } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 2; + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.1F}; + float[] grads1_init = {0.01F, 0.01F}; + + float learningRate = 3.0F; + + float epsilon = 1e-6F; + float epsilon1 = 1e-2F; + try (TestSession session = TestSession.createTestSession(tf_mode)) { + Ops tf = session.getTF(); + Shape shape0 = Shape.of(var0_init.length); + Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); + + /* build the GradsAnvVars */ + List gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + SGD instance = new SGD(tf, learningRate); + Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + + Variable momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); + assertEquals(momentumSlot0.asOutput().shape(), var0.asOutput().shape()); + Variable momentumSlot1 = instance.getSlot(var1.asOutput(), MOMENTUM).get(); + assertEquals(momentumSlot1.asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /** initialize the accumulators */ + session.run(tf.init()); + + /** make sure the variables were initialized properly */ + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + float[][] expectedVar0 = { + {0.7F, 1.7F}, + {0.66999996F, 1.6700001F}, + {0.66699994F, 1.667F}, + {0.66669995F, 1.6667F}, + {0.66666996F, 1.66667F} + }; + float[][] expectedVar1 = { + {2.97F, 3.97F}, + {2.967F, 3.967F}, + {2.9667F, 3.9667F}, + {2.96667F, 3.96667F}, + {2.966667F, 3.966667F} + }; + for (int step = 0; step < numSteps; step++) { + session.run(update, instance.getFeedDict()); + session.evaluate(expectedVar0[step], var0); + session.evaluate(expectedVar1[step], var1); + learningRate *= 0.1; + instance.setLearningRate(learningRate); + } + } + } } From d8fab044e35d973aa80e7cde3765a166f783c38b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 14 Sep 2020 12:45:48 -0400 Subject: [PATCH 05/14] Moved Optimizers to Keras. Added support for chanign learning rate. 
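Intended usage of the learning-rate API (a sketch condensed from SGDTest.testWithLearningRateDecay in this series; the variable names, initial values, the "SGDExample" op name and the 0.1F decay factor are illustrative assumptions, and imports and scaffolding match that test):

    try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) {
      Ops tf = session.getTF();
      Variable<TFloat32> var0 = tf.withName("var0").variable(Shape.of(2), TFloat32.DTYPE);
      Assign<TFloat32> var0Init = tf.assign(var0, tf.constant(new float[] {1.0F, 2.0F}));
      Constant<TFloat32> grads0 = tf.constant(new float[] {0.1F, 0.1F});

      float learningRate = 3.0F;
      SGD instance = new SGD(tf, learningRate);

      /* build the GradsAndVars and the update op */
      List<Optimizer.GradAndVar<? extends TType>> gradsAndVars = new ArrayList<>();
      gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
      Op update = instance.applyGradients(gradsAndVars, "SGDExample");

      /* initialize the variable and the optimizer slots */
      session.run(var0Init);
      session.run(tf.init());

      for (int step = 0; step < 2; step++) {
        // getFeedDict() maps the learning-rate placeholder to its current scalar tensor,
        // so the runner is fed the latest rate on every step
        session.run(update, instance.getFeedDict());
        learningRate *= 0.1F;                   // decay between steps
        instance.setLearningRate(learningRate); // closes the old tensor and rebuilds the feed dict
      }
    }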
--- .../src/bazel/op_generator/op_generator.cc | 10 +- .../annotations/org/tensorflow/op/NnOps.java | 139 ++++-- .../java/org/tensorflow/op/core/Abort.java | 1 - .../org/tensorflow/op/core/AssertThat.java | 1 - .../op/core/AssignAddVariableOp.java | 1 - .../op/core/AssignSubVariableOp.java | 1 - .../tensorflow/op/core/AssignVariableOp.java | 1 - .../org/tensorflow/op/core/BarrierClose.java | 1 - .../tensorflow/op/core/BarrierInsertMany.java | 1 - .../tensorflow/op/core/ConsumeMutexLock.java | 1 - .../tensorflow/op/core/ControlTrigger.java | 1 - .../op/core/DeleteSessionTensor.java | 1 - .../tensorflow/op/core/DestroyResourceOp.java | 1 - .../org/tensorflow/op/core/DeviceIndex.java | 3 + .../tensorflow/op/core/InitializeTable.java | 1 - .../op/core/InitializeTableFromTextFile.java | 1 - .../tensorflow/op/core/LookupTableImport.java | 1 - .../tensorflow/op/core/LookupTableInsert.java | 1 - .../tensorflow/op/core/LookupTableRemove.java | 1 - .../java/org/tensorflow/op/core/MapClear.java | 1 - .../java/org/tensorflow/op/core/MapStage.java | 1 - .../gen/java/org/tensorflow/op/core/NoOp.java | 1 - .../tensorflow/op/core/OrderedMapClear.java | 1 - .../tensorflow/op/core/OrderedMapStage.java | 1 - .../java/org/tensorflow/op/core/Print.java | 1 - .../op/core/ResourceScatterAdd.java | 1 - .../op/core/ResourceScatterDiv.java | 1 - .../op/core/ResourceScatterMax.java | 1 - .../op/core/ResourceScatterMin.java | 1 - .../op/core/ResourceScatterMul.java | 1 - .../op/core/ResourceScatterNdAdd.java | 1 - .../op/core/ResourceScatterNdMax.java | 2 + .../op/core/ResourceScatterNdMin.java | 2 + .../op/core/ResourceScatterNdSub.java | 1 - .../op/core/ResourceScatterNdUpdate.java | 1 - .../op/core/ResourceScatterSub.java | 1 - .../op/core/ResourceScatterUpdate.java | 1 - .../op/core/ResourceStridedSliceAssign.java | 1 - .../org/tensorflow/op/core/ScatterNdMax.java | 3 + .../org/tensorflow/op/core/ScatterNdMin.java | 3 + .../gen/java/org/tensorflow/op/core/Send.java | 1 - .../java/org/tensorflow/op/core/Stage.java | 1 - .../org/tensorflow/op/core/StageClear.java | 1 - .../tensorflow/op/core/TensorArrayClose.java | 1 - .../core/TensorForestCreateTreeVariable.java | 1 - .../op/core/TensorForestTreeDeserialize.java | 1 - .../op/core/TensorScatterNdMax.java | 3 + .../op/core/TensorScatterNdMin.java | 3 + .../op/core/XlaSpmdFullToShardShape.java | 3 + .../op/core/XlaSpmdShardToFullShape.java | 3 + .../tensorflow/op/data/DatasetToTfRecord.java | 1 - .../tensorflow/op/data/DeleteIterator.java | 1 - .../tensorflow/op/data/DeleteMemoryCache.java | 1 - .../op/data/DeleteMultiDeviceIterator.java | 1 - .../op/data/DeserializeIterator.java | 1 - .../op/data/InitializeTableFromDataset.java | 2 + .../org/tensorflow/op/data/MakeIterator.java | 1 - .../tensorflow/op/data/RegisterDataset.java | 3 + .../op/data/ShuffleAndRepeatDataset.java | 2 +- .../tensorflow/op/data/ShuffleDataset.java | 2 +- .../op/data/experimental/CompressElement.java | 3 + .../data/experimental/DataServiceDataset.java | 3 + .../data/experimental/DatasetToTFRecord.java | 1 - .../experimental/DummyIterationCounter.java | 3 + .../StatsAggregatorSetSummaryWriter.java | 1 - .../data/experimental/UncompressElement.java | 3 + .../estimator/BoostedTreesCreateEnsemble.java | 1 - ...stedTreesCreateQuantileStreamResource.java | 1 - .../BoostedTreesDeserializeEnsemble.java | 1 - ...eesQuantileStreamResourceAddSummaries.java | 1 - ...reesQuantileStreamResourceDeserialize.java | 1 - ...ostedTreesQuantileStreamResourceFlush.java | 1 - 
.../estimator/BoostedTreesUpdateEnsemble.java | 1 - .../BoostedTreesUpdateEnsembleV2.java | 1 - .../tensorflow/op/image/ExtractGlimpse.java | 2 +- .../java/org/tensorflow/op/io/QueueClose.java | 1 - .../org/tensorflow/op/io/QueueEnqueue.java | 1 - .../tensorflow/op/io/QueueEnqueueMany.java | 1 - .../org/tensorflow/op/io/ReaderReset.java | 1 - .../tensorflow/op/io/ReaderRestoreState.java | 1 - .../java/org/tensorflow/op/io/WriteFile.java | 1 - .../op/linalg/BandedTriangularSolve.java | 3 + .../java/org/tensorflow/op/math/BesselI0.java | 3 + .../java/org/tensorflow/op/math/BesselI1.java | 3 + .../org/tensorflow/op/math/DenseBincount.java | 3 + .../tensorflow/op/math/special/BesselJ0.java | 3 + .../tensorflow/op/math/special/BesselJ1.java | 3 + .../tensorflow/op/math/special/BesselK0.java | 3 + .../tensorflow/op/math/special/BesselK0e.java | 3 + .../tensorflow/op/math/special/BesselK1.java | 3 + .../tensorflow/op/math/special/BesselK1e.java | 3 + .../tensorflow/op/math/special/BesselY0.java | 3 + .../tensorflow/op/math/special/BesselY1.java | 3 + .../tensorflow/op/ragged/RaggedBincount.java | 3 + .../op/ragged/RaggedCountSparseOutput.java | 3 + .../org/tensorflow/op/ragged/RaggedCross.java | 3 + .../op/random/AnonymousSeedGenerator.java | 3 + .../op/random/DeleteRandomSeedGenerator.java | 1 - .../op/random/DeleteSeedGenerator.java | 2 + .../org/tensorflow/op/random/RngSkip.java | 1 - ...StatelessParameterizedTruncatedNormal.java | 3 + .../experimental/DummySeedGenerator.java | 3 + .../op/sparse/DenseCountSparseOutput.java | 3 + .../SparseAccumulatorApplyGradient.java | 1 - .../tensorflow/op/sparse/SparseBincount.java | 3 + .../op/sparse/SparseCountSparseOutput.java | 3 + .../org/tensorflow/op/sparse/SparseCross.java | 2 +- .../op/sparse/SparseCrossHashed.java | 3 + .../op/summary/CloseSummaryWriter.java | 1 - .../op/summary/CreateSummaryDbWriter.java | 1 - .../op/summary/CreateSummaryFileWriter.java | 1 - .../op/summary/FlushSummaryWriter.java | 1 - .../tensorflow/op/summary/ImportEvent.java | 1 - .../op/summary/WriteAudioSummary.java | 1 - .../op/summary/WriteGraphSummary.java | 1 - .../op/summary/WriteHistogramSummary.java | 1 - .../op/summary/WriteImageSummary.java | 1 - .../op/summary/WriteRawProtoSummary.java | 1 - .../op/summary/WriteScalarSummary.java | 1 - .../tensorflow/op/summary/WriteSummary.java | 1 - .../op/tpu/ConfigureTPUEmbedding.java | 1 - .../tpu/EnqueueTPUEmbeddingIntegerBatch.java | 1 - .../EnqueueTPUEmbeddingRaggedTensorBatch.java | 2 + .../tpu/EnqueueTPUEmbeddingSparseBatch.java | 1 - .../EnqueueTPUEmbeddingSparseTensorBatch.java | 1 - .../org/tensorflow/op/tpu/InfeedEnqueue.java | 1 - .../tpu/InfeedEnqueuePrelinearizedBuffer.java | 1 - .../tensorflow/op/tpu/InfeedEnqueueTuple.java | 1 - .../tpu/LoadTPUEmbeddingADAMParameters.java | 1 - ...EmbeddingADAMParametersGradAccumDebug.java | 1 - .../LoadTPUEmbeddingAdadeltaParameters.java | 1 - ...ddingAdadeltaParametersGradAccumDebug.java | 1 - .../LoadTPUEmbeddingAdagradParameters.java | 1 - ...eddingAdagradParametersGradAccumDebug.java | 1 - ...TPUEmbeddingCenteredRMSPropParameters.java | 1 - .../tpu/LoadTPUEmbeddingFTRLParameters.java | 1 - ...EmbeddingFTRLParametersGradAccumDebug.java | 1 - ...TPUEmbeddingMDLAdagradLightParameters.java | 1 - .../LoadTPUEmbeddingMomentumParameters.java | 1 - ...ddingMomentumParametersGradAccumDebug.java | 1 - ...TPUEmbeddingProximalAdagradParameters.java | 1 - ...oximalAdagradParametersGradAccumDebug.java | 1 - ...oadTPUEmbeddingProximalYogiParameters.java | 2 + 
...gProximalYogiParametersGradAccumDebug.java | 2 + .../LoadTPUEmbeddingRMSPropParameters.java | 1 - ...eddingRMSPropParametersGradAccumDebug.java | 1 - ...ngStochasticGradientDescentParameters.java | 1 - ...adientDescentParametersGradAccumDebug.java | 2 + .../org/tensorflow/op/tpu/OutfeedEnqueue.java | 1 - .../op/tpu/OutfeedEnqueueTuple.java | 1 - ...eveTPUEmbeddingProximalYogiParameters.java | 3 + ...gProximalYogiParametersGradAccumDebug.java | 3 + ...adientDescentParametersGradAccumDebug.java | 3 + .../op/tpu/SendTPUEmbeddingGradients.java | 1 - .../op/tpu/ShutdownDistributedTPU.java | 1 - .../op/tpu/TPUReplicateMetadata.java | 1 - .../op/train/AccumulatorApplyGradient.java | 1 - .../op/train/AccumulatorSetGlobalStep.java | 1 - .../op/train/MergeV2Checkpoints.java | 1 - .../org/tensorflow/op/train/NegTrain.java | 1 - .../ResourceAccumulatorApplyGradient.java | 1 - .../ResourceAccumulatorSetGlobalStep.java | 1 - .../op/train/ResourceApplyAdaMax.java | 1 - .../op/train/ResourceApplyAdadelta.java | 1 - .../op/train/ResourceApplyAdagrad.java | 1 - .../op/train/ResourceApplyAdagradDa.java | 1 - .../op/train/ResourceApplyAdam.java | 1 - .../train/ResourceApplyAdamWithAmsgrad.java | 1 - .../op/train/ResourceApplyAddSign.java | 1 - .../train/ResourceApplyCenteredRmsProp.java | 1 - .../op/train/ResourceApplyFtrl.java | 8 +- .../train/ResourceApplyGradientDescent.java | 1 - .../op/train/ResourceApplyKerasMomentum.java | 1 - .../op/train/ResourceApplyMomentum.java | 1 - .../op/train/ResourceApplyPowerSign.java | 1 - .../train/ResourceApplyProximalAdagrad.java | 1 - .../ResourceApplyProximalGradientDescent.java | 1 - .../op/train/ResourceApplyRmsProp.java | 1 - .../op/train/ResourceSparseApplyAdadelta.java | 1 - .../op/train/ResourceSparseApplyAdagrad.java | 1 - .../train/ResourceSparseApplyAdagradDa.java | 1 - .../train/ResourceSparseApplyAdagradV2.java | 1 - .../ResourceSparseApplyCenteredRmsProp.java | 1 - .../op/train/ResourceSparseApplyFtrl.java | 7 +- .../ResourceSparseApplyKerasMomentum.java | 1 - .../op/train/ResourceSparseApplyMomentum.java | 1 - .../ResourceSparseApplyProximalAdagrad.java | 1 - ...rceSparseApplyProximalGradientDescent.java | 1 - .../op/train/ResourceSparseApplyRmsProp.java | 1 - .../java/org/tensorflow/op/train/Save.java | 1 - .../org/tensorflow/op/train/SaveSlices.java | 1 - .../org/tensorflow/op/train/SdcaShrinkL1.java | 1 - .../gen/java/org/tensorflow/op/xla/Send.java | 1 - .../main/java/org/tensorflow/op/core/NN.java | 379 --------------- .../op/nn/SigmoidCrossEntropyWithLogits.java | 108 +++++ .../op/nn/SoftmaxCrossEntropyWithLogits.java | 214 +++++++++ .../SparseSoftmaxCrossEntropyWithLogits.java | 161 +++++++ .../main/java/org/tensorflow/types/TBool.java | 9 +- .../java/org/tensorflow/types/TString.java | 11 +- .../framework/optimizers/Momentum.java | 163 ++++--- .../framework/optimizers/Nadam.java | 295 ++++++++++++ .../framework/optimizers/Optimizer.java | 189 +++++--- .../framework/optimizers/Optimizers.java | 41 ++ .../framework/optimizers/RMSProp.java | 223 +++++---- .../schedules/PiecewiseConstantDecay.java | 58 +++ .../optimizers/schedules/PolynomialDecay.java | 127 +++++ .../framework/optimizers/MomentumTest.java | 182 +++---- .../framework}/optimizers/NadamTest.java | 307 ++++++------ .../framework/optimizers/OptimizersTest.java | 134 ++++++ .../framework/optimizers/RMSPropTest.java | 450 ++++++++++++++++++ .../schedules/PiecewiseConstantDecayTest.java | 16 + .../schedules/PolynomialDecayTest.java | 24 + .../org/tensorflow/framework}/utils/ND.java | 38 +- 
.../framework}/utils/TestSession.java | 261 +++++----- .../tensorflow/keras/optimizers/Nadam.java | 429 ----------------- .../keras/optimizers/OptimizerInterface.java | 49 -- .../keras/optimizers/Optimizers.java | 125 ----- .../tensorflow/keras/optimizers/RMSProp.java | 188 -------- .../org/tensorflow/keras/optimizers/SGD.java | 188 -------- .../keras/optimizers/RMSPropTest.java | 444 ----------------- 220 files changed, 2601 insertions(+), 2651 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/NN.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java rename tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java => tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java (58%) rename {tensorflow-keras/src/test/java/org/tensorflow/keras => tensorflow-framework/src/test/java/org/tensorflow/framework}/optimizers/NadamTest.java (50%) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java rename {tensorflow-keras/src/test/java/org/tensorflow/keras => tensorflow-framework/src/test/java/org/tensorflow/framework}/utils/ND.java (96%) rename {tensorflow-keras/src/test/java/org/tensorflow/keras => tensorflow-framework/src/test/java/org/tensorflow/framework}/utils/TestSession.java (82%) delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java delete mode 100644 tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java delete mode 100644 tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc index 03db4be125b..843f3bdb247 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc +++ b/tensorflow-core/tensorflow-core-api/src/bazel/op_generator/op_generator.cc @@ -514,11 +514,15 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, Javadoc 
name_javadoc = Javadoc::Create("The name of this op, as known by TensorFlow core engine"); string quoted_string = "\"" + op.graph_op_name() + "\""; writer.WriteFieldWithInitializer(nameVariable, PUBLIC|STATIC|FINAL, &name_javadoc, quoted_string ); - writer.EndLine(); - for (const ArgumentSpec& output : op.outputs()) { - writer.WriteField(output.var(), PRIVATE); + + if(!op.outputs().empty()) { + writer.EndLine(); + for (const ArgumentSpec& output : op.outputs()) { + writer.WriteField(output.var(), PRIVATE); + } } + RenderConstructor(op, op_class, &writer); writer.EndType(); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 8374a864ec2..33caf02d890 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -20,7 +20,6 @@ import java.util.List; import org.tensorflow.DataType; import org.tensorflow.Operand; -import org.tensorflow.op.core.NN; import org.tensorflow.op.nn.AvgPool; import org.tensorflow.op.nn.AvgPool3d; import org.tensorflow.op.nn.AvgPool3dGrad; @@ -84,10 +83,13 @@ import org.tensorflow.op.nn.Relu; import org.tensorflow.op.nn.Relu6; import org.tensorflow.op.nn.Selu; +import org.tensorflow.op.nn.SigmoidCrossEntropyWithLogits; import org.tensorflow.op.nn.Softmax; +import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.nn.Softsign; import org.tensorflow.op.nn.SpaceToBatch; import org.tensorflow.op.nn.SpaceToDepth; +import org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits; import org.tensorflow.op.nn.TopK; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -1756,49 +1758,53 @@ public Selu selu(Operand features) { } /** - * Computes sigmoid cross entropy given `logits`. + * Computes sigmoid cross entropy given logits. * *

Measures the probability error in discrete classification tasks in which each class is * independent and not mutually exclusive. For instance, one could perform multilabel * classification where a picture can contain both an elephant and a dog at the same time. * - *

For brevity, let `x = logits`, `z = labels`. The logistic loss is + *

For brevity, let x = logits, z = labels. The logistic loss in + * pseudo-code is * *

-   *      z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
-   *      = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
-   *      = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
-   *      = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
-   *      = (1 - z) * x + log(1 + exp(-x))
-   *      = x - x * z + log(1 + exp(-x))
+   *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+   *   = (1 - z) * x + log(1 + exp(-x))
+   *   = x - x * z + log(1 + exp(-x))
    *  
* - *

For x < 0, to avoid overflow in exp(-x), we reformulate the above + *

For x < 0, to avoid overflow in exp(-x), we reformulate the above * *

-   *       x - x * z + log(1 + exp(-x))
-   *       = log(exp(x)) - x * z + log(1 + exp(-x))
-   *       = - x * z + log(1 + exp(x))
+   *  x - x * z + log(1 + exp(-x))
+   *   = log(exp(x)) - x * z + log(1 + exp(-x))
+   *   = - x * z + log(1 + exp(x))
    *  
* *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent * formulation * *

-   *      max(x, 0) - x * z + log(1 + exp(-abs(x)))
+   *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
    *  
* - *

`logits` and `labels` must have the same type and shape. + *

logits and labels must have the same type and shape. + * + *

* * @param scope The TensorFlow scope * @param labels the labels * @param logits the logits of type float32 or float64 * @param the type of labels and logits * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits' and labels' do not have the same shape */ public Operand sigmoidCrossEntropyWithLogits(Operand labels, Operand logits) { - return NN.sigmoidCrossEntropyWithLogits(scope, labels, logits); + return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); } /** @@ -1817,48 +1823,54 @@ public Softmax softmax(Operand logits) { } /** - * Computes softmax cross entropy between `logits` and `labels`. + * Computes softmax cross entropy between logits and labels. * *

Measures the probability error in discrete classification tasks in which the classes are * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is * labeled with one and only one label: an image can be a dog or a truck, but not both. * - *

**NOTE:** While the classes are mutually exclusive, their probabilities need not be. All - * that is required is that each row of `labels` is a valid probability distribution. If they are - * not, the computation of the gradient will be incorrect. + *

NOTE: + * + *

While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of labels is a valid probability distribution. If they + * are not, the computation of the gradient will be incorrect. * - *

If using exclusive `labels` (wherein one and only one class is true at a time), see - * `sparse_softmax_cross_entropy_with_logits`. + *

If using exclusive labels (wherein one and only one class is true at a time), + * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} * *

Usage: * *

-   *    >>> logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
-   *    >>> labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]
-   *    >>> tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
-   *    
+   *    Operand<TFloat32> logits =
+   *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+   *    Operand<TFloat32> labels =
+   *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+   *    Operand<TFloat32> output =
+   *        tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   *    // output Shape = [2]
+   *    // dataType = FLOAT (1)
+   *    // values { 0.169846, 0.824745 }
    *  
* - *

Backpropagation will happen into both `logits` and `labels`. To disallow backpropagation - * into `labels`, pass label tensors through `tf.stop_gradient` before feeding it to this - * function. + *

Backpropagation will happen into both logits and labels. To + * disallow backpropagation into labels, pass label tensors through + * tf.stopGradient before feeding it to this function. * * @param scope current scope * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape `[batch_size, num_classes]`, - * each row of `labels[i]` must be a valid probability distribution. + * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] + * , each row of labels[i] must be a valid probability distribution. * @param logits Per-label activations, typically a linear output. These activation energies are * interpreted as unnormalized log probabilities. * @param axis The class dimension. -1 is the last dimension. - * @param the data type of the logits * @param the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as `logits` and its shape is the - * same as `labels` except that it does not have the last dimension of `labels`. + * @return the softmax cross entropy loss. Its type is the same as logits and its + * shape is the same as labels except that it does not have the last dimension of + * labels. */ - public Operand softmaxCrossEntropyWithLogits( - Operand labels, Operand logits, int axis) { - return NN.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + public Operand softmaxCrossEntropyWithLogits( + Operand labels, Operand logits, int axis) { + return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); } /** @@ -2050,22 +2062,51 @@ public SpaceToDepth spaceToDepth(Operand input, Long blo } /** - * Computes sparse softmax cross entropy between `logits` and `labels`. + * Computes sparse softmax cross entropy between logits and labels. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

For this operation, the probability of a given label is considered exclusive. That is, soft + * classes are not allowed, and the labels vector must provide a single specific + * index for the true class for each row of logits (each minibatch entry). For soft + * softmax classification with a probability distribution for each entry, see {@link + * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. + * + *

WARNING: + * + *

This op expects unscaled logits, since it performs a softmax on logits + * internally for efficiency. Do not call this op with the output of softmax, + * as it will produce incorrect results. + * + *

A common use case is to have logits of shape [batchSize, numClasses] and have + * labels of shape [batchSize], but higher dimensions are supported, in which case + * the dim-th dimension is assumed to be of size numClasses. + * logits must have the dataType of TFloat16, TFloat32 + * , or TFloat64, and labels must have the dtype of TInt32 + * or TInt64. * * @param scope current scope - * @param labels `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of `labels` and - * result) and dtype `int32` or `int64`. Each entry in `labels` must be an index in `[0, - * num_classes)`. Other values will raise an exception when this op is run on CPU, and return - * `NaN` for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape `[d_0, d_1, ..., - * d_{r-1}, num_classes]` and dtype `float16`, `float32`, or `float64`. These activation - * energies are interpreted as unnormalized log probabilities. - * @return A `Tensor` of the same shape as `labels` and of the same type as `logits` with the - * softmax cross entropy loss. + * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r + * is rank of labels and result) and the dataType is TInt32 + * or TInt64. Each entry in labels must be an index in [0, + * numClasses). Other values will raise an exception when this op is run on CPU, and + * return NaN for corresponding loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., + * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, + * or TFloat64. These activation energies are interpreted as unnormalized log + * probabilities. + * @return A Tensor of the same shape as labels and of the same type as + * logits with the softmax cross entropy loss. + * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank + * of the labels is not equal to the rank of the logits minus one. 
*/ public Operand sparseSoftmaxCrossEntropyWithLogits( Operand labels, Operand logits) { - return NN.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); + return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Abort.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Abort.java index 53e9401dfa2..a84f2405b19 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Abort.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Abort.java @@ -104,7 +104,6 @@ public static Options exitWithoutError(Boolean exitWithoutError) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "Abort"; - private Abort(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssertThat.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssertThat.java index 950830b7462..dce70c04e5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssertThat.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssertThat.java @@ -90,7 +90,6 @@ public static Options summarize(Long summarize) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "Assert"; - private AssertThat(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignAddVariableOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignAddVariableOp.java index 5adaccf15e0..53edc808882 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignAddVariableOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignAddVariableOp.java @@ -55,7 +55,6 @@ public static AssignAddVariableOp create(Scope scope, Operand< /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AssignAddVariableOp"; - private AssignAddVariableOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignSubVariableOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignSubVariableOp.java index 4bb683c97d2..372a71b2168 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignSubVariableOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignSubVariableOp.java @@ -55,7 +55,6 @@ public static AssignSubVariableOp create(Scope scope, Operand< /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AssignSubVariableOp"; - private AssignSubVariableOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignVariableOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignVariableOp.java index 90cabd12a24..ac08d62f9a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignVariableOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/AssignVariableOp.java @@ -55,7 +55,6 @@ public static AssignVariableOp create(Scope scope, 
Operand /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AssignVariableOp"; - private AssignVariableOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierClose.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierClose.java index a777d684ec1..514f4f50edf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierClose.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierClose.java @@ -95,7 +95,6 @@ public static Options cancelPendingEnqueues(Boolean cancelPendingEnqueues) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BarrierClose"; - private BarrierClose(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierInsertMany.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierInsertMany.java index 31488738838..b652c11a35c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierInsertMany.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/BarrierInsertMany.java @@ -63,7 +63,6 @@ public static BarrierInsertMany create(Scope scope, Operand mutexLock) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ConsumeMutexLock"; - private ConsumeMutexLock(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ControlTrigger.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ControlTrigger.java index 721112b8204..e40715c9f2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ControlTrigger.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ControlTrigger.java @@ -48,7 +48,6 @@ public static ControlTrigger create(Scope scope) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ControlTrigger"; - private ControlTrigger(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeleteSessionTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeleteSessionTensor.java index 50c7615a0ff..5f92cc26ca2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeleteSessionTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeleteSessionTensor.java @@ -50,7 +50,6 @@ public static DeleteSessionTensor create(Scope scope, Operand handle) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteSessionTensor"; - private DeleteSessionTensor(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DestroyResourceOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DestroyResourceOp.java index e1958682ee1..8a427166874 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DestroyResourceOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DestroyResourceOp.java @@ -88,7 +88,6 @@ public static Options 
ignoreLookupError(Boolean ignoreLookupError) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DestroyResourceOp"; - private DestroyResourceOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeviceIndex.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeviceIndex.java index 26f984e840d..f033d3fcc9d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeviceIndex.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DeviceIndex.java @@ -68,6 +68,9 @@ public Output asOutput() { return index; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DeviceIndex"; + private Output index; private DeviceIndex(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTable.java index 5de2ca6ff07..48662ed420d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTable.java @@ -54,7 +54,6 @@ public static InitializeTable create(Scope sc /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InitializeTableV2"; - private InitializeTable(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTableFromTextFile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTableFromTextFile.java index 0a88cea3ef2..2050c4d8628 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTableFromTextFile.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/InitializeTableFromTextFile.java @@ -121,7 +121,6 @@ public static Options delimiter(String delimiter) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InitializeTableFromTextFileV2"; - private InitializeTableFromTextFile(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableImport.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableImport.java index a94393a50f1..9884a40e3cb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableImport.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableImport.java @@ -57,7 +57,6 @@ public static LookupTableImport create(Scope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LookupTableImportV2"; - private LookupTableImport(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableInsert.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableInsert.java index c31784ea942..0f09ae25d1b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableInsert.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableInsert.java @@ -57,7 +57,6 @@ public static 
LookupTableInsert create(Scope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LookupTableInsertV2"; - private LookupTableInsert(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableRemove.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableRemove.java index 584e7e1325c..41463ad7539 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableRemove.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/LookupTableRemove.java @@ -54,7 +54,6 @@ public static LookupTableRemove create(Scope scope, Operand /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LookupTableRemoveV2"; - private LookupTableRemove(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java index ea7581ef2c7..bad1e90554f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapClear.java @@ -145,7 +145,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MapClear"; - private MapClear(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java index 5d72ce8f22f..9291b32d53b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/MapStage.java @@ -160,7 +160,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MapStage"; - private MapStage(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NoOp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NoOp.java index 862aabcd795..922b5d55ce3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NoOp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/NoOp.java @@ -46,7 +46,6 @@ public static NoOp create(Scope scope) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "NoOp"; - private NoOp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java index 29f4133ce09..05a1b7ab984 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapClear.java @@ -145,7 +145,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OrderedMapClear"; - private OrderedMapClear(Operation operation) { 
super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java index b51f94c148a..7e02973e3c6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/OrderedMapStage.java @@ -162,7 +162,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OrderedMapStage"; - private OrderedMapStage(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Print.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Print.java index 3e96c00d369..52b933329a0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Print.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Print.java @@ -105,7 +105,6 @@ public static Options end(String end) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "PrintV2"; - private Print(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterAdd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterAdd.java index 5383062823b..0966dd5fcc4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterAdd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterAdd.java @@ -75,7 +75,6 @@ public static ResourceScatterAdd create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterAdd"; - private ResourceScatterAdd(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterDiv.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterDiv.java index ed950863242..9560bddf284 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterDiv.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterDiv.java @@ -75,7 +75,6 @@ public static ResourceScatterDiv create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterDiv"; - private ResourceScatterDiv(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMax.java index 7553fab4812..ce952ee19ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMax.java @@ -75,7 +75,6 @@ public static ResourceScatterMax create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterMax"; - private ResourceScatterMax(Operation operation) { super(operation); } diff --git 
a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMin.java index 68518b4c640..51ec6b7637e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMin.java @@ -75,7 +75,6 @@ public static ResourceScatterMin create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterMin"; - private ResourceScatterMin(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMul.java index f52b338de57..2d5f71e006d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterMul.java @@ -75,7 +75,6 @@ public static ResourceScatterMul create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterMul"; - private ResourceScatterMul(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdAdd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdAdd.java index 5abfcbea5ee..11e45c33098 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdAdd.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdAdd.java @@ -125,7 +125,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterNdAdd"; - private ResourceScatterNdAdd(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMax.java index e24e3d68fef..82c1f766308 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMax.java @@ -91,6 +91,8 @@ public static Options useLocking(Boolean useLocking) { return new Options().useLocking(useLocking); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ResourceScatterNdMax"; private ResourceScatterNdMax(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMin.java index 3ffc78afa87..88e107c65c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdMin.java @@ -91,6 +91,8 @@ public static Options useLocking(Boolean useLocking) { return new Options().useLocking(useLocking); } + /** The name of this op, as known by TensorFlow core engine */ + public 
static final String OP_NAME = "ResourceScatterNdMin"; private ResourceScatterNdMin(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdSub.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdSub.java index c4b6060d611..267099b7cfc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdSub.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdSub.java @@ -125,7 +125,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterNdSub"; - private ResourceScatterNdSub(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdUpdate.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdUpdate.java index b47fb4a1367..4a1e875bc97 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdUpdate.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterNdUpdate.java @@ -127,7 +127,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterNdUpdate"; - private ResourceScatterNdUpdate(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterSub.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterSub.java index 2559ff21a93..7b772fab997 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterSub.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterSub.java @@ -75,7 +75,6 @@ public static ResourceScatterSub create(Sco /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterSub"; - private ResourceScatterSub(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterUpdate.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterUpdate.java index eff04c6c08a..067ddf5f205 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterUpdate.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceScatterUpdate.java @@ -66,7 +66,6 @@ public static ResourceScatterUpdate create( /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceScatterUpdate"; - private ResourceScatterUpdate(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceStridedSliceAssign.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceStridedSliceAssign.java index 2002140573b..4deb4c55f64 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceStridedSliceAssign.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ResourceStridedSliceAssign.java @@ -176,7 +176,6 @@ public static Options 
shrinkAxisMask(Long shrinkAxisMask) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceStridedSliceAssign"; - private ResourceStridedSliceAssign(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMax.java index 851cbb16cf4..da94c783cae 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMax.java @@ -107,6 +107,9 @@ public Output asOutput() { return outputRef; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ScatterNdMax"; + private Output outputRef; private ScatterNdMax(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMin.java index a3e3d4c9790..5aea70bc929 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/ScatterNdMin.java @@ -107,6 +107,9 @@ public Output asOutput() { return outputRef; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ScatterNdMin"; + private Output outputRef; private ScatterNdMin(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Send.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Send.java index bf3db1cd88a..d679b85319a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Send.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Send.java @@ -97,7 +97,6 @@ public static Options clientTerminated(Boolean clientTerminated) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "Send"; - private Send(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Stage.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Stage.java index 408b6eca252..526462b02f4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Stage.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Stage.java @@ -151,7 +151,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "Stage"; - private Stage(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java index 755e7ab72d9..60e51559f74 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/StageClear.java @@ -145,7 +145,6 @@ public static Options sharedName(String sharedName) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "StageClear"; - private StageClear(Operation operation) { 
super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayClose.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayClose.java index a16e856ae72..62180e8e5ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayClose.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorArrayClose.java @@ -52,7 +52,6 @@ public static TensorArrayClose create(Scope scope, Operand handle) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "TensorArrayCloseV3"; - private TensorArrayClose(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestCreateTreeVariable.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestCreateTreeVariable.java index e647f58b2f3..5ca6ffa1cd0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestCreateTreeVariable.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestCreateTreeVariable.java @@ -51,7 +51,6 @@ public static TensorForestCreateTreeVariable create(Scope scope, Operand tree /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "TensorForestCreateTreeVariable"; - private TensorForestCreateTreeVariable(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestTreeDeserialize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestTreeDeserialize.java index 5fb704b2361..a5e1638035e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestTreeDeserialize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorForestTreeDeserialize.java @@ -51,7 +51,6 @@ public static TensorForestTreeDeserialize create(Scope scope, Operand treeHan /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "TensorForestTreeDeserialize"; - private TensorForestTreeDeserialize(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMax.java index d040cf0639a..a14b20195af 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMax.java @@ -65,6 +65,9 @@ public Output asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "TensorScatterMax"; + private Output output; private TensorScatterNdMax(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMin.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMin.java index 797878d9c76..b202b72eebd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMin.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/TensorScatterNdMin.java @@ -65,6 +65,9 @@ public Output asOutput() { return 
output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "TensorScatterMin"; + private Output output; private TensorScatterNdMin(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java index 51f2c2b5dde..6615d2ef9f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdFullToShardShape.java @@ -68,6 +68,9 @@ public Output asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "XlaSpmdFullToShardShape"; + private Output output; private XlaSpmdFullToShardShape(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java index 5a120fb6fb8..75e31c7c317 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/XlaSpmdShardToFullShape.java @@ -70,6 +70,9 @@ public Output asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "XlaSpmdShardToFullShape"; + private Output output; private XlaSpmdShardToFullShape(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java index 41617c9b690..114e11074dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetToTfRecord.java @@ -54,7 +54,6 @@ public static DatasetToTfRecord create(Scope scope, Operand inputDataset, Ope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DatasetToTFRecord"; - private DatasetToTfRecord(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteIterator.java index ec3629a8eb7..69f3af096bb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteIterator.java @@ -51,7 +51,6 @@ public static DeleteIterator create(Scope scope, Operand handle, Operand d /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteIterator"; - private DeleteIterator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMemoryCache.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMemoryCache.java index 3c0f37dc409..21c33030b66 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMemoryCache.java +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeleteMemoryCache.java @@ -49,7 +49,6 @@ public static DeleteMemoryCache create(Scope scope, Operand handle, Operand multiDevi /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteMultiDeviceIterator"; - private DeleteMultiDeviceIterator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeserializeIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeserializeIterator.java index 528fd09bc53..4f772fd5028 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeserializeIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DeserializeIterator.java @@ -52,7 +52,6 @@ public static DeserializeIterator create(Scope scope, Operand resourceHandle, /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeserializeIterator"; - private DeserializeIterator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java index 527c951377b..05a263a4ec1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/InitializeTableFromDataset.java @@ -46,6 +46,8 @@ public static InitializeTableFromDataset create(Scope scope, Operand tableHan return new InitializeTableFromDataset(opBuilder.build()); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "InitializeTableFromDataset"; private InitializeTableFromDataset(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MakeIterator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MakeIterator.java index 685574a92d5..4aace25184e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MakeIterator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/MakeIterator.java @@ -54,7 +54,6 @@ public static MakeIterator create(Scope scope, Operand dataset, Operand it /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MakeIterator"; - private MakeIterator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java index 7e0695768c2..5705413165c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/RegisterDataset.java @@ -65,6 +65,9 @@ public Output asOutput() { return datasetId; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RegisterDataset"; + private Output datasetId; private RegisterDataset(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java index 1f6d2697497..c5703e8e85c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleAndRepeatDataset.java @@ -119,7 +119,7 @@ public Output asOutput() { } /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ShuffleAndRepeatDataset"; + public static final String OP_NAME = "ShuffleAndRepeatDatasetV2"; private Output handle; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java index ce3de6f787f..3dd522e319c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ShuffleDataset.java @@ -117,7 +117,7 @@ public Output asOutput() { } /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ShuffleDatasetV2"; + public static final String OP_NAME = "ShuffleDatasetV3"; private Output handle; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java index e56a8cde614..9e4bfb34a8b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/CompressElement.java @@ -60,6 +60,9 @@ public Output asOutput() { return (Output) compressed; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "CompressElement"; + private Output compressed; private CompressElement(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java index 4623ec2ea5d..b3e853f9de4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DataServiceDataset.java @@ -122,6 +122,9 @@ public Output asOutput() { return (Output) handle; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DataServiceDataset"; + private Output handle; private DataServiceDataset(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DatasetToTFRecord.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DatasetToTFRecord.java index ee0f494fa73..6e0a0b8f2dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DatasetToTFRecord.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DatasetToTFRecord.java @@ -54,7 +54,6 @@ public static DatasetToTFRecord create(Scope scope, Operand inputDataset, Ope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = 
"ExperimentalDatasetToTFRecord"; - private DatasetToTFRecord(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java index b2febce6f81..72f83847285 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/DummyIterationCounter.java @@ -56,6 +56,9 @@ public Output asOutput() { return (Output) handle; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DummyIterationCounter"; + private Output handle; private DummyIterationCounter(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java index ad63ff056d8..1af246d8313 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/StatsAggregatorSetSummaryWriter.java @@ -50,7 +50,6 @@ public static StatsAggregatorSetSummaryWriter create(Scope scope, Operand sta /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "StatsAggregatorSetSummaryWriter"; - private StatsAggregatorSetSummaryWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java index 5d94d8699ab..c5732154e94 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/experimental/UncompressElement.java @@ -76,6 +76,9 @@ public Iterator> iterator() { return (Iterator) components.iterator(); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "UncompressElement"; + private List> components; private UncompressElement(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateEnsemble.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateEnsemble.java index 04613900567..8841988b36d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateEnsemble.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateEnsemble.java @@ -54,7 +54,6 @@ public static BoostedTreesCreateEnsemble create(Scope scope, Operand treeEnse /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesCreateEnsemble"; - private BoostedTreesCreateEnsemble(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateQuantileStreamResource.java 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateQuantileStreamResource.java index 59362970a58..802a61ecb2a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateQuantileStreamResource.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesCreateQuantileStreamResource.java @@ -88,7 +88,6 @@ public static Options maxElements(Long maxElements) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesCreateQuantileStreamResource"; - private BoostedTreesCreateQuantileStreamResource(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesDeserializeEnsemble.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesDeserializeEnsemble.java index 6fd83d5785d..15371fb4df9 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesDeserializeEnsemble.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesDeserializeEnsemble.java @@ -56,7 +56,6 @@ public static BoostedTreesDeserializeEnsemble create(Scope scope, Operand tre /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesDeserializeEnsemble"; - private BoostedTreesDeserializeEnsemble(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceAddSummaries.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceAddSummaries.java index 76480c4be6d..418ff3b2ff6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceAddSummaries.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceAddSummaries.java @@ -56,7 +56,6 @@ public static BoostedTreesQuantileStreamResourceAddSummaries create(Scope scope, /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesQuantileStreamResourceAddSummaries"; - private BoostedTreesQuantileStreamResourceAddSummaries(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceDeserialize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceDeserialize.java index 82066e267d2..6efb58ed60c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceDeserialize.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceDeserialize.java @@ -54,7 +54,6 @@ public static BoostedTreesQuantileStreamResourceDeserialize create(Scope scope, /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesQuantileStreamResourceDeserialize"; - private BoostedTreesQuantileStreamResourceDeserialize(Operation operation) { super(operation); } diff --git 
a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceFlush.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceFlush.java index 359b7b63ff6..cc10434a582 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceFlush.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesQuantileStreamResourceFlush.java @@ -97,7 +97,6 @@ public static Options generateQuantiles(Boolean generateQuantiles) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesQuantileStreamResourceFlush"; - private BoostedTreesQuantileStreamResourceFlush(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsemble.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsemble.java index c1b7bb44559..e6ddcf3d2da 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsemble.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsemble.java @@ -79,7 +79,6 @@ public static BoostedTreesUpdateEnsemble create(Scope scope, Operand treeEnse /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesUpdateEnsemble"; - private BoostedTreesUpdateEnsemble(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsembleV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsembleV2.java index afd9d646c2f..ceaff116fd1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsembleV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/estimator/BoostedTreesUpdateEnsembleV2.java @@ -118,7 +118,6 @@ public static Options logitsDimension(Long logitsDimension) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "BoostedTreesUpdateEnsembleV2"; - private BoostedTreesUpdateEnsembleV2(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractGlimpse.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractGlimpse.java index 172b24f74ac..05bc1d924b3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractGlimpse.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/image/ExtractGlimpse.java @@ -199,7 +199,7 @@ public Output asOutput() { } /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ExtractGlimpse"; + public static final String OP_NAME = "ExtractGlimpseV2"; private Output glimpse; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueClose.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueClose.java index fb55f708140..ea7791f143e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueClose.java +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueClose.java @@ -91,7 +91,6 @@ public static Options cancelPendingEnqueues(Boolean cancelPendingEnqueues) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "QueueCloseV2"; - private QueueClose(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueue.java index 546981e8abf..a159b0cd17c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueue.java @@ -96,7 +96,6 @@ public static Options timeoutMs(Long timeoutMs) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "QueueEnqueueV2"; - private QueueEnqueue(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueueMany.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueueMany.java index 48df5d3b9d3..b1f9cbd6807 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueueMany.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/QueueEnqueueMany.java @@ -101,7 +101,6 @@ public static Options timeoutMs(Long timeoutMs) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "QueueEnqueueManyV2"; - private QueueEnqueueMany(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderReset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderReset.java index 6e3de01134b..243d4a72080 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderReset.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderReset.java @@ -49,7 +49,6 @@ public static ReaderReset create(Scope scope, Operand readerHandle) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ReaderResetV2"; - private ReaderReset(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderRestoreState.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderRestoreState.java index b0abea1257c..431ba079ffc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderRestoreState.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/ReaderRestoreState.java @@ -56,7 +56,6 @@ public static ReaderRestoreState create(Scope scope, Operand readerHandle, Op /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ReaderRestoreStateV2"; - private ReaderRestoreState(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/WriteFile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/WriteFile.java index d9fba243fc3..d1c9dd9b9c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/WriteFile.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/io/WriteFile.java @@ 
-54,7 +54,6 @@ public static WriteFile create(Scope scope, Operand filename, Operand asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BandedTriangularSolve"; + private Output output; private BandedTriangularSolve(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java index 53e1ac83c32..45dcd2b8e4c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI0.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselI0"; + private Output y; private BesselI0(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java index 638f6b06972..148758aa5a4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/BesselI1.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselI1"; + private Output y; private BesselI1(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java index 165be081102..e38d559f4ae 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/DenseBincount.java @@ -112,6 +112,9 @@ public Output asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DenseBincount"; + private Output output; private DenseBincount(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java index bc73a0c9c02..8d2184a49cb 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ0.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselJ0"; + private Output y; private BesselJ0(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java index 4fd21c42288..d8f9621a36c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselJ1.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final 
String OP_NAME = "BesselJ1"; + private Output y; private BesselJ1(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java index 8f3c540b185..eaae243f83f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselK0"; + private Output y; private BesselK0(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java index 1a8f9761c08..c57ae64e233 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK0e.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselK0e"; + private Output y; private BesselK0e(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java index bcaaf6f6f9c..1858d25fe3d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselK1"; + private Output y; private BesselK1(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java index c6590805d54..e4a5cc23efd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselK1e.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselK1e"; + private Output y; private BesselK1e(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java index 86843a30939..9228d1b6145 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY0.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselY0"; + private Output y; private BesselY0(Operation operation) { diff --git 
a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java index 2cdc4ad7df0..0461416b808 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/math/special/BesselY1.java @@ -59,6 +59,9 @@ public Output asOutput() { return y; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "BesselY1"; + private Output y; private BesselY1(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java index fc1636d8d64..1e0224aa9ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedBincount.java @@ -115,6 +115,9 @@ public Output asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RaggedBincount"; + private Output output; private RaggedBincount(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java index 07b364f6ebb..4829e49488b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCountSparseOutput.java @@ -140,6 +140,9 @@ public Output outputDenseShape() { return outputDenseShape; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RaggedCountSparseOutput"; + private Output outputIndices; private Output outputValues; private Output outputDenseShape; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java index fa6e811969b..9ea32878257 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/ragged/RaggedCross.java @@ -94,6 +94,9 @@ public Output outputRowSplits() { return outputRowSplits; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RaggedCross"; + private Output outputValues; private Output outputRowSplits; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/AnonymousSeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/AnonymousSeedGenerator.java index f55c3222977..c724bb6d110 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/AnonymousSeedGenerator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/AnonymousSeedGenerator.java @@ -63,6 +63,9 @@ public Output deleter() { return deleter; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "AnonymousSeedGenerator"; + private Output handle; private Output deleter; diff 
--git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteRandomSeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteRandomSeedGenerator.java index 9bc34d98ebb..23b154f9d75 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteRandomSeedGenerator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteRandomSeedGenerator.java @@ -49,7 +49,6 @@ public static DeleteRandomSeedGenerator create(Scope scope, Operand handle, O /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "DeleteRandomSeedGenerator"; - private DeleteRandomSeedGenerator(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteSeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteSeedGenerator.java index 2872ea12aff..16982946d1f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteSeedGenerator.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/DeleteSeedGenerator.java @@ -46,6 +46,8 @@ public static DeleteSeedGenerator create(Scope scope, Operand handle, Operand return new DeleteSeedGenerator(opBuilder.build()); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DeleteSeedGenerator"; private DeleteSeedGenerator(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RngSkip.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RngSkip.java index e3411c3b989..f41cff35b04 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RngSkip.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/RngSkip.java @@ -58,7 +58,6 @@ public static RngSkip create(Scope scope, Operand resource, Operand a /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "RngSkip"; - private RngSkip(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java index 053db14c986..179160463c7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/StatelessParameterizedTruncatedNormal.java @@ -72,6 +72,9 @@ public Output asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "StatelessParameterizedTruncatedNormal"; + private Output output; private StatelessParameterizedTruncatedNormal(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java index 92e58ba293f..dd537fa2d68 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/random/experimental/DummySeedGenerator.java @@ -56,6 +56,9 @@ public Output asOutput() { return (Output) handle; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DummySeedGenerator"; + private Output handle; private DummySeedGenerator(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java index 62c489cab7b..ed390a7ba47 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/DenseCountSparseOutput.java @@ -132,6 +132,9 @@ public Output outputDenseShape() { return outputDenseShape; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "DenseCountSparseOutput"; + private Output outputIndices; private Output outputValues; private Output outputDenseShape; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorApplyGradient.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorApplyGradient.java index 3f36812ea57..328fe0c49ea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorApplyGradient.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseAccumulatorApplyGradient.java @@ -69,7 +69,6 @@ public static SparseAccumulatorApplyGradient create(Scope scop /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SparseAccumulatorApplyGradient"; - private SparseAccumulatorApplyGradient(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java index 7902e8544dd..344e27f1346 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseBincount.java @@ -117,6 +117,9 @@ public Output asOutput() { return output; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "SparseBincount"; + private Output output; private SparseBincount(Operation operation) { diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java index 36230bc774e..5e5566db5ec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCountSparseOutput.java @@ -136,6 +136,9 @@ public Output outputDenseShape() { return outputDenseShape; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "SparseCountSparseOutput"; + private Output outputIndices; private Output outputValues; private Output outputDenseShape; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCross.java 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCross.java index 06113f0315b..1cd471349c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCross.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCross.java @@ -118,7 +118,7 @@ public Output outputShape() { } /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "SparseCross"; + public static final String OP_NAME = "SparseCrossV2"; private Output outputIndices; private Output outputValues; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCrossHashed.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCrossHashed.java index 9e7cc9b1e6c..2fc6976079e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCrossHashed.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/sparse/SparseCrossHashed.java @@ -122,6 +122,9 @@ public Output outputShape() { return outputShape; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "SparseCrossHashed"; + private Output outputIndices; private Output outputValues; private Output outputShape; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CloseSummaryWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CloseSummaryWriter.java index ff9735f0b07..f5d95d50976 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CloseSummaryWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CloseSummaryWriter.java @@ -47,7 +47,6 @@ public static CloseSummaryWriter create(Scope scope, Operand writer) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "CloseSummaryWriter"; - private CloseSummaryWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryDbWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryDbWriter.java index 61e7405f74d..8e40aa798d6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryDbWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryDbWriter.java @@ -56,7 +56,6 @@ public static CreateSummaryDbWriter create(Scope scope, Operand writer, Opera /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "CreateSummaryDbWriter"; - private CreateSummaryDbWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryFileWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryFileWriter.java index d113ebcf3f6..e429fab20e2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryFileWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/CreateSummaryFileWriter.java @@ -57,7 +57,6 @@ public static CreateSummaryFileWriter create(Scope scope, Operand writer, Ope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = 
"CreateSummaryFileWriter"; - private CreateSummaryFileWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/FlushSummaryWriter.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/FlushSummaryWriter.java index 6b1e610c632..e1586542972 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/FlushSummaryWriter.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/FlushSummaryWriter.java @@ -47,7 +47,6 @@ public static FlushSummaryWriter create(Scope scope, Operand writer) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "FlushSummaryWriter"; - private FlushSummaryWriter(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImportEvent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImportEvent.java index 9b6dc173abe..7bd97de571e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImportEvent.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/ImportEvent.java @@ -50,7 +50,6 @@ public static ImportEvent create(Scope scope, Operand writer, Operand writer, Operand WriteHistogramSummary create(Scope scope, Oper /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteHistogramSummary"; - private WriteHistogramSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java index 286d584d695..757ddf59a1c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteImageSummary.java @@ -94,7 +94,6 @@ public static Options maxImages(Long maxImages) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteImageSummary"; - private WriteImageSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteRawProtoSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteRawProtoSummary.java index 524b56bed7a..75499c1ff69 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteRawProtoSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteRawProtoSummary.java @@ -53,7 +53,6 @@ public static WriteRawProtoSummary create(Scope scope, Operand writer, Operan /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteRawProtoSummary"; - private WriteRawProtoSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java index 2317db7bdeb..f173651001a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteScalarSummary.java @@ -57,7 +57,6 @@ public static WriteScalarSummary create(Scope scope, Operand /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteScalarSummary"; - private WriteScalarSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteSummary.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteSummary.java index 6d257f948e8..5404e593f27 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteSummary.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/summary/WriteSummary.java @@ -58,7 +58,6 @@ public static WriteSummary create(Scope scope, Operand writ /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "WriteSummary"; - private WriteSummary(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConfigureTPUEmbedding.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConfigureTPUEmbedding.java index 1905a3082b3..76bccd51f83 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConfigureTPUEmbedding.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConfigureTPUEmbedding.java @@ -48,7 +48,6 @@ public static ConfigureTPUEmbedding create(Scope scope, String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ConfigureTPUEmbedding"; - private ConfigureTPUEmbedding(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingIntegerBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingIntegerBatch.java index 4198c38e11c..0a1a80c7a0a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingIntegerBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingIntegerBatch.java @@ -93,7 +93,6 @@ public static Options deviceOrdinal(Long deviceOrdinal) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "EnqueueTPUEmbeddingIntegerBatch"; - private EnqueueTPUEmbeddingIntegerBatch(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java index c605dafbc87..bf4da86d05d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingRaggedTensorBatch.java @@ -177,6 +177,8 @@ public static Options maxSequenceLengths(List maxSequenceLengths) { return new Options().maxSequenceLengths(maxSequenceLengths); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "EnqueueTPUEmbeddingRaggedTensorBatch"; private EnqueueTPUEmbeddingRaggedTensorBatch(Operation operation) { super(operation); diff --git 
a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java index 23288018938..2cb7dfb674b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseBatch.java @@ -146,7 +146,6 @@ public static Options combiners(List combiners) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "EnqueueTPUEmbeddingSparseBatch"; - private EnqueueTPUEmbeddingSparseBatch(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java index 59018e1b3e5..3d93c6a0f71 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/EnqueueTPUEmbeddingSparseTensorBatch.java @@ -178,7 +178,6 @@ public static Options maxSequenceLengths(List maxSequenceLengths) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "EnqueueTPUEmbeddingSparseTensorBatch"; - private EnqueueTPUEmbeddingSparseTensorBatch(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueue.java index 9c79df444e3..391d51a9ab0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueue.java @@ -135,7 +135,6 @@ public static Options deviceOrdinal(Long deviceOrdinal) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InfeedEnqueue"; - private InfeedEnqueue(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueuePrelinearizedBuffer.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueuePrelinearizedBuffer.java index b1d32f70ec9..9344352791c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueuePrelinearizedBuffer.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueuePrelinearizedBuffer.java @@ -84,7 +84,6 @@ public static Options deviceOrdinal(Long deviceOrdinal) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InfeedEnqueuePrelinearizedBuffer"; - private InfeedEnqueuePrelinearizedBuffer(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueueTuple.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueueTuple.java index d2a95d84244..b439df84f71 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueueTuple.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/InfeedEnqueueTuple.java @@ -124,7 +124,6 
@@ public static Options deviceOrdinal(Long deviceOrdinal) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "InfeedEnqueueTuple"; - private InfeedEnqueueTuple(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParameters.java index 9e60fae350e..744688cee23 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParameters.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingADAMParameters"; - private LoadTPUEmbeddingADAMParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParametersGradAccumDebug.java index 58cfa5cf465..63df2e6aa79 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingADAMParametersGradAccumDebug.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingADAMParametersGradAccumDebug"; - private LoadTPUEmbeddingADAMParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParameters.java index e4f4228f0f1..43535a2aff8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParameters.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingAdadeltaParameters"; - private LoadTPUEmbeddingAdadeltaParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.java index 76af15dc0b6..ce1b759ee60 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdadeltaParametersGradAccumDebug.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingAdadeltaParametersGradAccumDebug"; - private LoadTPUEmbeddingAdadeltaParametersGradAccumDebug(Operation 
operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParameters.java index dc4f5c62341..f9e16c5b5d6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParameters.java @@ -133,7 +133,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingAdagradParameters"; - private LoadTPUEmbeddingAdagradParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParametersGradAccumDebug.java index 6551f875f2d..7f8df653745 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingAdagradParametersGradAccumDebug.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingAdagradParametersGradAccumDebug"; - private LoadTPUEmbeddingAdagradParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingCenteredRMSPropParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingCenteredRMSPropParameters.java index d4a0103654c..f0b704cfaa1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingCenteredRMSPropParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingCenteredRMSPropParameters.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingCenteredRMSPropParameters"; - private LoadTPUEmbeddingCenteredRMSPropParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParameters.java index a65301f6348..c96edf58894 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParameters.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingFTRLParameters"; - private LoadTPUEmbeddingFTRLParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParametersGradAccumDebug.java 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParametersGradAccumDebug.java index 5a1c165428d..f0a85bd945a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingFTRLParametersGradAccumDebug.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingFTRLParametersGradAccumDebug"; - private LoadTPUEmbeddingFTRLParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMDLAdagradLightParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMDLAdagradLightParameters.java index 407cf842f19..f418a70cc8c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMDLAdagradLightParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMDLAdagradLightParameters.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingMDLAdagradLightParameters"; - private LoadTPUEmbeddingMDLAdagradLightParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParameters.java index 35b8479749b..718bdc24f5c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParameters.java @@ -133,7 +133,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingMomentumParameters"; - private LoadTPUEmbeddingMomentumParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParametersGradAccumDebug.java index babc2de15fd..424c3c846c0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingMomentumParametersGradAccumDebug.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingMomentumParametersGradAccumDebug"; - private LoadTPUEmbeddingMomentumParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParameters.java index 0ebad625abe..7b7265e9b82 100644 --- 
a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParameters.java @@ -133,7 +133,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingProximalAdagradParameters"; - private LoadTPUEmbeddingProximalAdagradParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.java index 80b05d47203..c18d2cf22f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug"; - private LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParameters.java index 651c8e189c4..2a96916c4f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParameters.java @@ -125,6 +125,8 @@ public static Options config(String config) { return new Options().config(config); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "LoadTPUEmbeddingProximalYogiParameters"; private LoadTPUEmbeddingProximalYogiParameters(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.java index 274accba1c9..e863dc554d6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingProximalYogiParametersGradAccumDebug.java @@ -127,6 +127,8 @@ public static Options config(String config) { return new Options().config(config); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "LoadTPUEmbeddingProximalYogiParametersGradAccumDebug"; private LoadTPUEmbeddingProximalYogiParametersGradAccumDebug(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParameters.java index d0e39d22edb..1f747282d55 
100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParameters.java @@ -135,7 +135,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingRMSPropParameters"; - private LoadTPUEmbeddingRMSPropParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.java index 98f1043a768..a7c8ed4812c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingRMSPropParametersGradAccumDebug.java @@ -137,7 +137,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingRMSPropParametersGradAccumDebug"; - private LoadTPUEmbeddingRMSPropParametersGradAccumDebug(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParameters.java index ca881823239..769d3436eda 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParameters.java @@ -131,7 +131,6 @@ public static Options config(String config) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "LoadTPUEmbeddingStochasticGradientDescentParameters"; - private LoadTPUEmbeddingStochasticGradientDescentParameters(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java index 76a13a489ad..e408844e484 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java @@ -130,6 +130,8 @@ public static Options config(String config) { return new Options().config(config); } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"; private LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueue.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueue.java index 
46ee54430d9..5b5f059a81e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueue.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueue.java @@ -49,7 +49,6 @@ public static OutfeedEnqueue create(Scope scope, Operand in /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OutfeedEnqueue"; - private OutfeedEnqueue(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueueTuple.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueueTuple.java index 25d110f1114..8bfd04b9a2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueueTuple.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/OutfeedEnqueueTuple.java @@ -50,7 +50,6 @@ public static OutfeedEnqueueTuple create(Scope scope, Iterable> input /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "OutfeedEnqueueTuple"; - private OutfeedEnqueueTuple(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParameters.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParameters.java index eaee4fdabc4..a46cae90359 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParameters.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParameters.java @@ -137,6 +137,9 @@ public Output m() { return m; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RetrieveTPUEmbeddingProximalYogiParameters"; + private Output parameters; private Output v; private Output m; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.java index ec57d8cb424..55535a573f6 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug.java @@ -143,6 +143,9 @@ public Output gradientAccumulators() { return gradientAccumulators; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug"; + private Output parameters; private Output v; private Output m; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java index f649b4d01fa..9f35ffcd8e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug.java @@ -139,6 +139,9 @@ public Output gradientAccumulators() { return gradientAccumulators; } + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug"; + private Output parameters; private Output gradientAccumulators; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/SendTPUEmbeddingGradients.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/SendTPUEmbeddingGradients.java index 25d25c1d5bf..482080bde5d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/SendTPUEmbeddingGradients.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/SendTPUEmbeddingGradients.java @@ -64,7 +64,6 @@ public static SendTPUEmbeddingGradients create(Scope scope, Iterable AccumulatorApplyGradient create(Scope scope, Ope /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AccumulatorApplyGradient"; - private AccumulatorApplyGradient(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorSetGlobalStep.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorSetGlobalStep.java index b57bb702669..9039d3a654d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorSetGlobalStep.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/AccumulatorSetGlobalStep.java @@ -56,7 +56,6 @@ public static AccumulatorSetGlobalStep create(Scope scope, Operand hand /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "AccumulatorSetGlobalStep"; - private AccumulatorSetGlobalStep(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/MergeV2Checkpoints.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/MergeV2Checkpoints.java index 4fe4a27171a..986553a2d8e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/MergeV2Checkpoints.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/MergeV2Checkpoints.java @@ -96,7 +96,6 @@ public static Options deleteOldDirs(Boolean deleteOldDirs) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "MergeV2Checkpoints"; - private MergeV2Checkpoints(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/NegTrain.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/NegTrain.java index f43928961a9..b3e5316ad7e 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/NegTrain.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/NegTrain.java @@ -68,7 +68,6 @@ public static NegTrain create(Scope scope, Operand wIn, Operand ResourceAccumulatorApplyGradient create(Scope sc /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceAccumulatorApplyGradient"; - private ResourceAccumulatorApplyGradient(Operation operation) { 
super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorSetGlobalStep.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorSetGlobalStep.java index e04784aef48..37570909340 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorSetGlobalStep.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceAccumulatorSetGlobalStep.java @@ -54,7 +54,6 @@ public static ResourceAccumulatorSetGlobalStep create(Scope scope, Operand ha /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceAccumulatorSetGlobalStep"; - private ResourceAccumulatorSetGlobalStep(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdaMax.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdaMax.java index 5efab216739..169da75fecd 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdaMax.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdaMax.java @@ -107,7 +107,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdaMax"; - private ResourceApplyAdaMax(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdadelta.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdadelta.java index 0121155a2ef..4323a39de45 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdadelta.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdadelta.java @@ -103,7 +103,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdadelta"; - private ResourceApplyAdadelta(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagrad.java index 8868193464f..b60ad1fc0e0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagrad.java @@ -117,7 +117,6 @@ public static Options updateSlots(Boolean updateSlots) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdagradV2"; - private ResourceApplyAdagrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagradDa.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagradDa.java index 7f5b26056ac..7c6f06634ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagradDa.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdagradDa.java @@ -101,7 +101,6 @@ public 
static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdagradDA"; - private ResourceApplyAdagradDa(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdam.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdam.java index 20a07b4865d..4a1aea5d355 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdam.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdam.java @@ -130,7 +130,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdam"; - private ResourceApplyAdam(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdamWithAmsgrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdamWithAmsgrad.java index ec13f10038d..a436bc7fdd2 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdamWithAmsgrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAdamWithAmsgrad.java @@ -114,7 +114,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAdamWithAmsgrad"; - private ResourceApplyAdamWithAmsgrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAddSign.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAddSign.java index e64354ec3bb..85c9c587979 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAddSign.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyAddSign.java @@ -104,7 +104,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyAddSign"; - private ResourceApplyAddSign(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyCenteredRmsProp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyCenteredRmsProp.java index c4c8d7c9ec7..6fc3a8a02ff 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyCenteredRmsProp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyCenteredRmsProp.java @@ -123,7 +123,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyCenteredRMSProp"; - private ResourceApplyCenteredRmsProp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyFtrl.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyFtrl.java index c9de01ad14d..e69b6b99959 100644 --- 
a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyFtrl.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyFtrl.java @@ -120,18 +120,16 @@ public static ResourceApplyFtrl create(Scope scope, Operand public static Options useLocking(Boolean useLocking) { return new Options().useLocking(useLocking); } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ResourceApplyFtrlV2"; - + /** * @param multiplyLinearByLr */ public static Options multiplyLinearByLr(Boolean multiplyLinearByLr) { return new Options().multiplyLinearByLr(multiplyLinearByLr); } - + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ResourceApplyFtrlV2"; private ResourceApplyFtrl(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyGradientDescent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyGradientDescent.java index 0e495bdd651..f33c6b9ca87 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyGradientDescent.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyGradientDescent.java @@ -90,7 +90,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyGradientDescent"; - private ResourceApplyGradientDescent(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyKerasMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyKerasMomentum.java index 3a986d617eb..3922439dcad 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyKerasMomentum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyKerasMomentum.java @@ -124,7 +124,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyKerasMomentum"; - private ResourceApplyKerasMomentum(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyMomentum.java index c441193d864..c554c8a939d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyMomentum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyMomentum.java @@ -124,7 +124,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyMomentum"; - private ResourceApplyMomentum(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyPowerSign.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyPowerSign.java index c1ba8b0ebd7..662c2253264 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyPowerSign.java +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyPowerSign.java @@ -104,7 +104,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyPowerSign"; - private ResourceApplyPowerSign(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalAdagrad.java index b51ce4698e1..8036d891e33 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalAdagrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalAdagrad.java @@ -100,7 +100,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyProximalAdagrad"; - private ResourceApplyProximalAdagrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalGradientDescent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalGradientDescent.java index 7f9c4f4e52c..3b217c88c67 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalGradientDescent.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyProximalGradientDescent.java @@ -97,7 +97,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyProximalGradientDescent"; - private ResourceApplyProximalGradientDescent(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyRmsProp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyRmsProp.java index 4c400f1aaa1..ae42295c1f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyRmsProp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceApplyRmsProp.java @@ -113,7 +113,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceApplyRMSProp"; - private ResourceApplyRmsProp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdadelta.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdadelta.java index 9a50137d196..baea98fc1f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdadelta.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdadelta.java @@ -101,7 +101,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyAdadelta"; - private ResourceSparseApplyAdadelta(Operation operation) { super(operation); } diff --git 
a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagrad.java index 69a3e775622..f7816e78d0c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagrad.java @@ -120,7 +120,6 @@ public static Options updateSlots(Boolean updateSlots) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyAdagrad"; - private ResourceSparseApplyAdagrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradDa.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradDa.java index 4f3189f074c..417eca86a80 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradDa.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradDa.java @@ -104,7 +104,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyAdagradDA"; - private ResourceSparseApplyAdagradDa(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradV2.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradV2.java index 30e6c19da15..f60d192c368 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradV2.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyAdagradV2.java @@ -121,7 +121,6 @@ public static Options updateSlots(Boolean updateSlots) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyAdagradV2"; - private ResourceSparseApplyAdagradV2(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyCenteredRmsProp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyCenteredRmsProp.java index ab9b9d3c38d..d6806c36abf 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyCenteredRmsProp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyCenteredRmsProp.java @@ -124,7 +124,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyCenteredRMSProp"; - private ResourceSparseApplyCenteredRmsProp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyFtrl.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyFtrl.java index 84caa503dec..a13382272c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyFtrl.java +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyFtrl.java @@ -125,18 +125,15 @@ public static Options useLocking(Boolean useLocking) { return new Options().useLocking(useLocking); } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "ResourceSparseApplyFtrlV2"; - /** * @param multiplyLinearByLr */ public static Options multiplyLinearByLr(Boolean multiplyLinearByLr) { return new Options().multiplyLinearByLr(multiplyLinearByLr); } - + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "ResourceSparseApplyFtrlV2"; private ResourceSparseApplyFtrl(Operation operation) { super(operation); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyKerasMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyKerasMomentum.java index 0284564f78c..b385403f989 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyKerasMomentum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyKerasMomentum.java @@ -129,7 +129,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyKerasMomentum"; - private ResourceSparseApplyKerasMomentum(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyMomentum.java index 5199932b5bc..bc303bfbbf0 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyMomentum.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyMomentum.java @@ -129,7 +129,6 @@ public static Options useNesterov(Boolean useNesterov) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyMomentum"; - private ResourceSparseApplyMomentum(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalAdagrad.java index e235a19f5d1..678601d6aea 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalAdagrad.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalAdagrad.java @@ -105,7 +105,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyProximalAdagrad"; - private ResourceSparseApplyProximalAdagrad(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalGradientDescent.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalGradientDescent.java index 08a9edc01c4..11ad213524c 100644 --- 
a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalGradientDescent.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyProximalGradientDescent.java @@ -101,7 +101,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyProximalGradientDescent"; - private ResourceSparseApplyProximalGradientDescent(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyRmsProp.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyRmsProp.java index 982e1f30eb7..8c519504f89 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyRmsProp.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/ResourceSparseApplyRmsProp.java @@ -116,7 +116,6 @@ public static Options useLocking(Boolean useLocking) { /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "ResourceSparseApplyRMSProp"; - private ResourceSparseApplyRmsProp(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Save.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Save.java index 781714d8121..c5de40fc91b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Save.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/Save.java @@ -63,7 +63,6 @@ public static Save create(Scope scope, Operand prefix, Operand /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SaveV2"; - private Save(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SaveSlices.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SaveSlices.java index e8e67190e63..73325d1d1bc 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SaveSlices.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SaveSlices.java @@ -89,7 +89,6 @@ public static SaveSlices create(Scope scope, Operand filename, Operand< /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SaveSlices"; - private SaveSlices(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SdcaShrinkL1.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SdcaShrinkL1.java index 24c6d53ef9d..748a2eacaec 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SdcaShrinkL1.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/train/SdcaShrinkL1.java @@ -56,7 +56,6 @@ public static SdcaShrinkL1 create(Scope scope, Iterable> weigh /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "SdcaShrinkL1"; - private SdcaShrinkL1(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Send.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Send.java index 
d1172f8e96f..b18a86458ca 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Send.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/Send.java @@ -55,7 +55,6 @@ public static Send create(Scope scope, Operand tensor, Stri /** The name of this op, as known by TensorFlow core engine */ public static final String OP_NAME = "XlaSend"; - private Send(Operation operation) { super(operation); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/NN.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/NN.java deleted file mode 100644 index b4fa7bd01de..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/NN.java +++ /dev/null @@ -1,379 +0,0 @@ -package org.tensorflow.op.core; - -import org.tensorflow.DataType; -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.math.*; -import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits; -import org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits; -import org.tensorflow.types.*; -import org.tensorflow.types.family.TNumber; -import org.tensorflow.op.dtypes.Cast; -import org.tensorflow.types.family.TType; -import org.tensorflow.op.linalg.Transpose; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -@Operator(group = "nn") -public abstract class NN { - - /** - * Computes softmax cross entropy between `logits` and `labels`. - * - *

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

**NOTE:** While the classes are mutually exclusive, their probabilities need not be. All - * that is required is that each row of `labels` is a valid probability distribution. If they are - * not, the computation of the gradient will be incorrect. - * - *

If using exclusive `labels` (wherein one and only one class is true at a time), see - * `sparse_softmax_cross_entropy_with_logits`. - * - *

Usage: - * - *

-   *   >>> logits = [[4.0, 2.0, 1.0], [0.0, 5.0, 1.0]]
-   *   >>> labels = [[1.0, 0.0, 0.0], [0.0, 0.8, 0.2]]
-   *   >>> tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
-   *   
-   * 
- * - *
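
    For orientation, a rough Java counterpart of the Python doctest above, using the helper
    being removed below, might look like the following sketch. The eager Ops.create() and
    tf.constant(float[][]) calls are assumptions about the surrounding TensorFlow Java API,
    imports are omitted, and the values simply mirror the doctest.

        // Hypothetical sketch: mirrors the Python doctest with the NN helper removed in this patch.
        Ops tf = Ops.create();                       // assumed eager Ops factory
        Operand<TFloat32> logits =
            tf.constant(new float[][] {{4.0f, 2.0f, 1.0f}, {0.0f, 5.0f, 1.0f}});
        Operand<TFloat32> labels =
            tf.constant(new float[][] {{1.0f, 0.0f, 0.0f}, {0.0f, 0.8f, 0.2f}});
        // axis = -1 selects the last (class) dimension, matching the Python default.
        Operand<TFloat32> loss =
            NN.softmaxCrossEntropyWithLogits(tf.scope(), labels, logits, -1);
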

Backpropagation will happen into both `logits` and `labels`. To disallow backpropagation - * into `labels`, pass label tensors through `tf.stop_gradient` before feeding it to this - * function. - * - * @param scope current scope - * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape `[batch_size, num_classes]`, - * each row of `labels[i]` must be a valid probability distribution. - * @param logits Per-label activations, typically a linear output. These activation energies are - * interpreted as unnormalized log probabilities. - * @param axis The class dimension. -1 is the last dimension. - * @param the data type of the logits - * @param the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as `logits` and its shape is the - * same as `labels` except that it does not have the last dimension of `labels`. - */ - @Endpoint(name = "softmaxCrossEntropyWithLogits") - public static Operand softmaxCrossEntropyWithLogits( - Scope scope, Operand labels, Operand logits, int axis) { - axis = axis % logits.asOutput().shape().numDimensions(); - if (axis < 0) { - axis += logits.asOutput().shape().numDimensions(); - } - - Operand precise_logits = - logits; // cannot use generics cause logits of bool gets cast to TFloat32 - - boolean convertToFloat32 = - logits.asOutput().dataType() == TFloat16.DTYPE - || logits.asOutput().dataType() == TBfloat16.DTYPE; - if (convertToFloat32) { - precise_logits = Cast.create(scope, logits, TFloat32.DTYPE); - } - /* cannot use generics on DataType because precis_logits may have been cast. */ - DataType dtype = precise_logits.asOutput().dataType(); - labels = Cast.create(scope, labels, dtype); - Operand inputRank = - Cast.create(scope, Rank.create(scope, precise_logits), TInt64.DTYPE); - Shape shape = logits.asOutput().shape(); - - // Move the dim to the end if dim is not the last dimension. - if (axis != -1 && axis != precise_logits.asOutput().shape().numDimensions() - 1) { - precise_logits = moveDimToEnd(scope, precise_logits, axis, inputRank); - labels = moveDimToEnd(scope, labels, axis, inputRank); - } - - Shape inputShape = precise_logits.asOutput().shape(); - precise_logits = flattenOuterDims(scope, precise_logits); - labels = flattenOuterDims(scope, labels); - SoftmaxCrossEntropyWithLogits smax = - SoftmaxCrossEntropyWithLogits.create(scope, precise_logits, labels); - /* cannot use generic on cost, because cost may be recast later. */ - Operand cost = smax.loss(); - Operand outputShape = - Slice.create( - scope, - Constant.vectorOf(scope, inputShape.asArray()), - Constant.vectorOf(scope, new long[] {0}), - Constant.vectorOf(scope, new long[] {inputShape.numDimensions() - 1})); - cost = Reshape.create(scope, cost, outputShape); - if (scope.env().isGraph() && !shape.hasUnknownDimension()) { - long[] array = shape.asArray(); - long[] newArray = new long[array.length - 1]; - if (axis < 0) { - axis = shape.numDimensions() + axis; - } - for (int i = 0; i < axis; i++) { - newArray[i] = shape.size(i); - } - for (int i = axis + 1; i < shape.numDimensions(); i++) { - newArray[i - 1] = shape.size(i); - } - Shape newShape = Shape.of(newArray); - cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newShape.asArray())); - } - - if (convertToFloat32) { - cost = Cast.create(scope, cost, logits.asOutput().dataType()); - } - return cost; - } - - /** - * Computes sparse softmax cross entropy between `logits` and `labels`. 
- * - * @param scope current scope - * @param labels `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of `labels` and - * result) and dtype `int32` or `int64`. Each entry in `labels` must be an index in `[0, - * num_classes)`. Other values will raise an exception when this op is run on CPU, and return - * `NaN` for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape `[d_0, d_1, ..., - * d_{r-1}, num_classes]` and dtype `float16`, `float32`, or `float64`. These activation - * energies are interpreted as unnormalized log probabilities. - * @return A `Tensor` of the same shape as `labels` and of the same type as `logits` with the - * softmax cross entropy loss. - */ - @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") - public static Operand sparseSoftmaxCrossEntropyWithLogits( - Scope scope, Operand labels, Operand logits) { - // assert shapeIsCompatible(labels.asOutput().shape(), logits.asOutput().shape()): - // String.format("Shapes %s and %s are incompatible", - // labels.asOutput().shape(), logits.asOutput().shape()); - scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); - /** cannot use generics on precise_logits as it may be recast later */ - Operand precise_logits = logits; - boolean convertToFloat32 = - logits.asOutput().dataType() == TFloat16.DTYPE - || logits.asOutput().dataType() == TBfloat16.DTYPE; - if (convertToFloat32) { - precise_logits = Cast.create(scope, logits, TFloat32.DTYPE); - } - Shape labelsStaticShape = labels.asOutput().shape(); - org.tensorflow.op.core.Shape labelsShape = - org.tensorflow.op.core.Shape.create(scope, labels); - Shape logitsShape = logits.asOutput().shape(); - Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1); - - boolean staticShapesFullyDefined = - !labelsStaticShape.hasUnknownDimension() && !logitsShortened.hasUnknownDimension(); - if (logitsShape.numDimensions() == 0) { - throw new IllegalArgumentException( - String.format("Logits cannot be scalars - received shape %s.", logitsShape)); - } - if (!logitsShape.hasUnknownDimension() - && !labelsStaticShape.hasUnknownDimension() - && labelsStaticShape.numDimensions() != logitsShape.numDimensions() - 1) { - throw new IllegalArgumentException( - String.format( - "Rank mismatch: Rank of labels (received %s) should equal rank of logits minus 1 (received %s).", - labelsStaticShape.toString(), logitsShape.toString())); - } - - if (staticShapesFullyDefined && !labelsStaticShape.equals(logitsShortened)) { - throw new IllegalArgumentException( - String.format( - "Shape mismatch: The shape of labels (received %s) " - + "should equal the shape of logits except for the last " - + "dimension (received %s).", - labelsStaticShape.toString(), logitsShape.toString())); - } - // Check if no reshapes are required. 
- if (logitsShape.numDimensions() == 2) { - SparseSoftmaxCrossEntropyWithLogits smax = - SparseSoftmaxCrossEntropyWithLogits.create(scope, precise_logits, labels); - Operand loss = smax.loss(); - if (logits.asOutput().dataType() == TFloat16.DTYPE) { - loss = Cast.create(scope, loss, TFloat16.DTYPE); - } - return loss; - } - - List shapeChecks = new ArrayList<>(); - - if (!staticShapesFullyDefined) { - shapeChecks.add( - AssertThat.create( - scope, - Equal.create( - scope, - org.tensorflow.op.core.Shape.create(scope, labels), - Shapes.take( - scope, - org.tensorflow.op.core.Shape.create(scope, logits), - Constant.scalarOf(scope, -1))), - Collections.singletonList( - Constant.scalarOf( - scope, - "Shape mismatch: The shape of labels " - + "should equal the shape of logits except for the last " - + "dimension ")))); - } - - // Reshape logits to 2 dim, labels to 1 dim. - long numClassses = logitsShape.size(logitsShape.numDimensions() - 1); - - precise_logits = - Reshape.create( - scope, precise_logits, Constant.vectorOf(scope, new long[] {-1, numClassses})); - labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1)); - scope.withControlDependencies(shapeChecks); - SparseSoftmaxCrossEntropyWithLogits smax = - SparseSoftmaxCrossEntropyWithLogits.create(scope, precise_logits, labels); - Operand cost = smax.loss(); - cost = Reshape.create(scope, cost, labelsShape); - if (logits.asOutput().dataType() == TFloat16.DTYPE) { - cost = Cast.create(scope, cost, TFloat16.DTYPE); - } - return cost; - } - - /** - * Computes sigmoid cross entropy given `logits`. - * - *

Measures the probability error in discrete classification tasks in which each class is - * independent and not mutually exclusive. For instance, one could perform multilabel - * classification where a picture can contain both an elephant and a dog at the same time. - * - *

For brevity, let `x = logits`, `z = labels`. The logistic loss is - * - *

-   *     z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
-   *     = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
-   *     = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
-   *     = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
-   *     = (1 - z) * x + log(1 + exp(-x))
-   *     = x - x * z + log(1 + exp(-x))
-   * 
- * - *

For x < 0, to avoid overflow in exp(-x), we reformulate the above - * - *

-   *      x - x * z + log(1 + exp(-x))
-   *      = log(exp(x)) - x * z + log(1 + exp(-x))
-   *      = - x * z + log(1 + exp(x))
-   * 
- * - *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent - * formulation - * - *

-   *     max(x, 0) - x * z + log(1 + exp(-abs(x)))
-   * 
- * - *

`logits` and `labels` must have the same type and shape. - * - * @param scope The TensorFlow scope - * @param labels the labels - * @param logits the logits of type float32 or float64 - * @param the type of labels and logits - * @return the component-wise logistic losses. - */ - @Endpoint(name = "sigmoidCrossEntropyWithLogits") - public static Operand sigmoidCrossEntropyWithLogits( - Scope scope, Operand labels, Operand logits) { - if (labels.asOutput().shape().numDimensions() != logits.asOutput().shape().numDimensions()) - throw new IllegalArgumentException( - String.format( - "logits and labels must have the same shape (%s vs %s)", - labels.asOutput().shape().toString(), logits.asOutput().shape())); - Operand zeros = - Cast.create(scope, ZerosLike.create(scope, logits), logits.asOutput().dataType()); - Operand cond = GreaterEqual.create(scope, logits, zeros); - - Operand relu_logits = Select.create(scope, cond, logits, zeros); - Operand neg_abs_logits = Select.create(scope, cond, Neg.create(scope, logits), logits); - return Add.create( - scope, - Sub.create(scope, relu_logits, Mul.create(scope, logits, labels)), - Log1p.create(scope, Exp.create(scope, neg_abs_logits))); - } - - /** - * Flattens logits' outer dimensions and keep its last dimension. - * - * @param scope the TensorFlow scope - * @param logits the logits - * @param the type of logits - * @return the flattened logits - */ - private static Operand flattenOuterDims(Scope scope, Operand logits) { - Operand one = Constant.scalarOf(scope, 1L); - - org.tensorflow.ndarray.Shape shape = logits.asOutput().shape(); - int ndims = shape.numDimensions(); - if (!shape.hasUnknownDimension()) { - long product = 1L; - boolean productValid = true; - for (int i = ndims - 2; i >= 0; i--) { - long d = shape.size(i); - if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) { - productValid = false; - break; - } - product *= d; - } - if (productValid) { - org.tensorflow.ndarray.Shape outputShape = Shape.of(product, shape.size(ndims - 1)); - return Reshape.create(scope, logits, Constant.vectorOf(scope, outputShape.asArray())); - } - } - - Operand rank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); - Operand rankMinusOne = Sub.create(scope, rank, one); - - Operand last_dim_size = - Slice.create( - scope, - org.tensorflow.op.core.Shape.create(scope, logits, TInt64.DTYPE), - rankMinusOne, - one); - Operand concat = - Concat.create( - scope, - Arrays.asList(Constant.vectorOf(scope, new long[] {-1}), last_dim_size), - Constant.scalarOf(scope, 0)); - return Reshape.create(scope, logits, concat); - } - - /** - * Move the dim to the end if dim is not the last dimension. - * - * @param scope The TensorFlow Scope - * @param input the input to reshape - * @param dim_index the index to move - * @param rank the number of Dimensions in the tensor - * @param the data type of the tensor. 
- * @param the data type of the rank - * @return the reshaped input - */ - private static Operand moveDimToEnd( - Scope scope, Operand input, int dim_index, Operand rank) { - DataType rankDType = rank.asOutput().dataType(); - Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankDType); - List> concatList = - Arrays.asList( - Range.create( - scope, - Cast.create(scope, Constant.scalarOf(scope, dim_index), rankDType), - one, - one), - Range.create( - scope, - Cast.create(scope, Constant.scalarOf(scope, (dim_index + 1)), rankDType), - rank, - one)); - return Transpose.create( - scope, input, Concat.create(scope, concatList, Constant.scalarOf(scope, 0))); - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java new file mode 100644 index 00000000000..4f3e9569103 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java @@ -0,0 +1,108 @@ +package org.tensorflow.op.nn; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.Select; +import org.tensorflow.op.core.ZerosLike; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +@Operator(group = "nn") +public class SigmoidCrossEntropyWithLogits { + + /** + * Computes sigmoid cross entropy given logits. + * + *

Measures the probability error in discrete classification tasks in which each class is + * independent and not mutually exclusive. For instance, one could perform multilabel + * classification where a picture can contain both an elephant and a dog at the same time. + * + *

For brevity, let x = logits, z = labels. The logistic loss in + * pseudo-code is + * + *

+   * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+   *  = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+   *  = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+   *  = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+   *  = (1 - z) * x + log(1 + exp(-x))
+   *  = x - x * z + log(1 + exp(-x))
+   * 
+ * + *

For x < 0, to avoid overflow in exp(-x), we reformulate the above + * + *

+   * x - x * z + log(1 + exp(-x))
+   *  = log(exp(x)) - x * z + log(1 + exp(-x))
+   *  = - x * z + log(1 + exp(x))
+   * 
+ * + *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent + * formulation + * + *

+   *   max(x, 0) - x * z + log(1 + exp(-abs(x)))
+   * 
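A minimal plain-Java sketch of the stable formulation above, outside the TensorFlow graph and purely illustrative (the method name is made up):

    // Numerically stable per-element loss: max(x, 0) - x * z + log(1 + exp(-|x|))
    static double stableSigmoidCrossEntropy(double x, double z) {
      return Math.max(x, 0.0) - x * z + Math.log1p(Math.exp(-Math.abs(x)));
    }
    // The naive form x - x * z + log(1 + exp(-x)) overflows for large negative x:
    // for x = -1000, z = 0, exp(1000) is infinite, whereas the stable form returns
    // log1p(exp(-1000)) ~= 0, which is the correct loss.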
+ * + *

logits and labels must have the same type and shape. + * + *

+ * + * @param scope The TensorFlow scope + * @param labels the labels + * @param logits the logits of type float32 or float64 + * @param the type of labels and logits + * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits' and labels' do not have the same shape + */ + @Endpoint(name = "sigmoidCrossEntropyWithLogits") + public static Operand sigmoidCrossEntropyWithLogits( + Scope scope, Operand labels, Operand logits) { + if (!isCompatible(labels.asOutput().shape(), logits.asOutput().shape())) { + throw new IllegalArgumentException( + String.format( + "logits and labels must have the same shape (%s vs %s)", + labels.asOutput().shape().toString(), logits.asOutput().shape())); + } + scope = scope.withSubScope("SigmoidCrossEntropyWithLogits"); + + Operand zeros = + Cast.create(scope, ZerosLike.create(scope, logits), logits.asOutput().dataType()); + Operand cond = GreaterEqual.create(scope, logits, zeros); + + Operand reluLogits = Select.create(scope, cond, logits, zeros); + Operand negAbsLogits = Select.create(scope, cond, Neg.create(scope, logits), logits); + return Add.create( + scope, + Sub.create(scope, reluLogits, Mul.create(scope, logits, labels)), + Log1p.create(scope, Exp.create(scope, negAbsLogits))); + } + /** + * Determine if 2 shapes are compatible + * + *

2 shapes are compatible if they have the same number of dimensions, and if the corresponding + * dimensions are equal, or at least one of the corresponding dimensions is unknown. + * + * @param shape the first shape + * @param other the second shape + * @return true, if the shapes are compatible. + */ + private static boolean isCompatible(Shape shape, Shape other) { + if (shape.numDimensions() != other.numDimensions()) return false; + for (int i = 0; i < shape.numDimensions(); i++) { + long aShapeDim = shape.size(i); + long bShapeDim = other.size(i); + if (aShapeDim == bShapeDim + || (aShapeDim == Shape.UNKNOWN_SIZE || bShapeDim == Shape.UNKNOWN_SIZE)) { + continue; + } + return false; + } + return true; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java new file mode 100644 index 00000000000..0c8bac697ed --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -0,0 +1,214 @@ +package org.tensorflow.op.nn; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.*; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.linalg.Transpose; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.List; + +@Operator(group = "nn") +public class SoftmaxCrossEntropyWithLogits { + + /** + * Computes softmax cross entropy between logits and labels. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of labels is a valid probability distribution. If they + * are not, the computation of the gradient will be incorrect. + * + *

If using exclusive labels (wherein one and only one class is true at a time), + * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} + * + *

Usage: + * + *

+   *   Operand<TFloat32> logits =
+   *       tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+   *   Operand<TFloat32> labels =
+   *       tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+   *   Operand<TFloat32> output =
+   *       tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   *   // output Shape = [2]
+   *   // dataType = FLOAT (1)
+   *   // values { 0.169846, 0.824745 }
+   * 
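The two values quoted above can be cross-checked with plain Java, independent of the op (illustrative only):

    // Recomputes softmax cross entropy for the two rows of the usage example.
    double[][] logits = {{4.0, 2.0, 1.0}, {0.0, 5.0, 1.0}};
    double[][] labels = {{1.0, 0.0, 0.0}, {0.0, 0.8, 0.2}};
    for (int row = 0; row < logits.length; row++) {
      double sum = 0;
      for (double logit : logits[row]) sum += Math.exp(logit);
      double loss = 0;
      for (int c = 0; c < logits[row].length; c++) {
        loss += labels[row][c] * (Math.log(sum) - logits[row][c]); // -label * log(softmax)
      }
      System.out.println(loss); // ~= 0.169846 for row 0, ~= 0.824745 for row 1
    }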
+ * + *

Backpropagation will happen into both logits and labels. To + * disallow backpropagation into labels, pass label tensors through + * tf.stopGradient before feeding it to this function. + * + * @param scope current scope + * @param labels Each vector along the class dimension should hold a valid probability + * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] + * , each row of labels[i] must be a valid probability distribution. + * @param logits Per-label activations, typically a linear output. These activation energies are + * interpreted as unnormalized log probabilities. + * @param axis The class dimension. -1 is the last dimension. + * @param the number type of the operands + * @return the softmax cross entropy loss. Its type is the same as logits and its + * shape is the same as labels except that it does not have the last dimension of + * labels. + */ + @Endpoint(name = "softmaxCrossEntropyWithLogits") + public static Operand softmaxCrossEntropyWithLogits( + Scope scope, Operand labels, Operand logits, int axis) { + scope = scope.withSubScope("SoftmaxCrossEntropyWithLogits"); + axis = axis % logits.asOutput().shape().numDimensions(); + if (axis < 0) { + axis += logits.asOutput().shape().numDimensions(); + } + + + boolean convertToFloat32 = + logits.asOutput().dataType() == TFloat16.DTYPE + || logits.asOutput().dataType() == TBfloat16.DTYPE; + if (convertToFloat32) { + Operand result = softmaxCrossEntropyWithLogits(scope, + Cast.create(scope, labels, TFloat32.DTYPE), + Cast.create(scope, logits, TFloat32.DTYPE), + axis); + return Cast.create(scope, result, logits.asOutput().dataType()); + } else if(!logits.asOutput().dataType().equals(labels.asOutput().dataType())) { + return softmaxCrossEntropyWithLogits(scope, + Cast.create(scope, labels, logits.asOutput().dataType()), + logits, + axis); + } + + Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); + Shape shape = logits.asOutput().shape(); + + // Move the dim to the end if dim is not the last dimension. + if (axis != -1 && axis != logits.asOutput().shape().numDimensions() - 1) { + logits = moveDimToEnd(scope, logits, axis, inputRank); + labels = moveDimToEnd(scope, labels, axis, inputRank); + } + + Shape inputShape = logits.asOutput().shape(); + logits = flattenOuterDims(scope, logits); + labels = flattenOuterDims(scope, labels); + + org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create( + scope, logits, (Operand)labels); + /* cannot use generic on cost, because cost may be recast later. */ + Operand cost = smax.loss(); + Operand outputShape = + Slice.create( + scope, + Constant.tensorOf(scope, inputShape), + Constant.arrayOf(scope, 0L), + Constant.arrayOf(scope, inputShape.numDimensions() - 1L)); + cost = Reshape.create(scope, cost, outputShape); + if (scope.env().isGraph() && !shape.hasUnknownDimension()) { + long[] array = shape.asArray(); + long[] newArray = new long[array.length - 1]; + if (axis < 0) { + axis = shape.numDimensions() + axis; + } + for (int i = 0; i < axis; i++) { + newArray[i] = shape.size(i); + } + for (int i = axis + 1; i < shape.numDimensions(); i++) { + newArray[i - 1] = shape.size(i); + } + cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newArray)); + } + + return cost; + } + + /** + * Flattens logits' outer dimensions and keep its last dimension. 
+ * + * @param scope the TensorFlow scope + * @param logits the logits + * @param the type of logits + * @return the flattened logits + */ + private static Operand flattenOuterDims(Scope scope, Operand logits) { + Operand one = Constant.scalarOf(scope, 1L); + + Shape shape = logits.asOutput().shape(); + int ndims = shape.numDimensions(); + if (!shape.hasUnknownDimension()) { + long product = 1L; + boolean productValid = true; + for (int i = ndims - 2; i >= 0; i--) { + long d = shape.size(i); + if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) { + productValid = false; + break; + } + product *= d; + } + if (productValid) { + return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.size(-1))); + } + } + + Operand rank = Cast.create(scope, Rank.create(scope, logits), TInt64.DTYPE); + Operand rankMinusOne = Sub.create(scope, rank, one); + + Operand lastDimSize = + Slice.create( + scope, + org.tensorflow.op.core.Shape.create(scope, logits, TInt64.DTYPE), + rankMinusOne, + one); + Operand concat = + Concat.create( + scope, + Arrays.asList(Constant.arrayOf(scope, -1L), lastDimSize), + Constant.scalarOf(scope, 0)); + return Reshape.create(scope, logits, concat); + } + + /** + * Move the dim to the end if dimIndex is not the last dimension. + * + * @param scope The TensorFlow Scope + * @param input the input to reshape + * @param dimIndex the index to move + * @param rank the number of Dimensions in the tensor + * @param the data type of the tensor. + * @param the data type of the rank + * @return the reshaped input + */ + private static Operand moveDimToEnd( + Scope scope, Operand input, int dimIndex, Operand rank) { + DataType rankDType = rank.asOutput().dataType(); + Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankDType); + List> concatList = + Arrays.asList( + Range.create( + scope, Cast.create(scope, Constant.scalarOf(scope, dimIndex), rankDType), one, one), + Range.create( + scope, + Cast.create(scope, Constant.scalarOf(scope, dimIndex + 1), rankDType), + one, + one)); + return Transpose.create( + scope, input, Concat.create(scope, concatList, Constant.scalarOf(scope, 0))); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java new file mode 100644 index 00000000000..ebd6f74e7d8 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -0,0 +1,161 @@ +package org.tensorflow.op.nn; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.AssertThat; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Shapes; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.Equal; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +@Operator(group = "nn") +public class SparseSoftmaxCrossEntropyWithLogits { + + /** + * Computes sparse softmax cross entropy between logits and labels. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

For this operation, the probability of a given label is considered exclusive. That is, soft + * classes are not allowed, and the labels vector must provide a single specific + * index for the true class for each row of logits (each minibatch entry). For soft + * softmax classification with a probability distribution for each entry, see {@link + * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. + * + *

WARNING: + * + *

This op expects unscaled logits, since it performs a softmax on logits + * internally for efficiency. Do not call this op with the output of softmax, + * as it will produce incorrect results. + * + *
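To make the label convention concrete, the per-row loss is -log(softmax(logits)[label]); a plain-Java sketch with made-up values (raw, unscaled logits, as the warning above requires):

    // Each label is a single class index into the corresponding row of logits.
    double[][] logits = {{4.0, 2.0, 1.0}, {0.0, 5.0, 1.0}};
    int[] labels = {0, 1};
    for (int row = 0; row < logits.length; row++) {
      double sum = 0;
      for (double logit : logits[row]) sum += Math.exp(logit);
      double loss = Math.log(sum) - logits[row][labels[row]]; // -log(softmax(logits)[label])
      System.out.println(loss); // ~= 0.1698 for row 0, ~= 0.0247 for row 1
    }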

A common use case is to have logits of shape [batchSize, numClasses] and have + * labels of shape [batchSize], but higher dimensions are supported, in which case + * the dim-th dimension is assumed to be of size numClasses. + * logits must have the dataType of TFloat16, TFloat32 + * , or TFloat64, and labels must have the dtype of TInt32 + * or TInt64. + * + * @param scope current scope + * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r + * is rank of labels and result) and the dataType is TInt32 + * or TInt64. Each entry in labels must be an index in [0, + * numClasses). Other values will raise an exception when this op is run on CPU, and + * return NaN for corresponding loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., + * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, + * or TFloat64. These activation energies are interpreted as unnormalized log + * probabilities. + * @return A Tensor of the same shape as labels and of the same type as + * logits with the softmax cross entropy loss. + * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank + * of the labels is not equal to the rank of the logits minus one. + */ + @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") + public static Operand sparseSoftmaxCrossEntropyWithLogits( + Scope scope, Operand labels, Operand logits) { + scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); + /** cannot use generics on preciseLogits as it may be recast later */ + Operand preciseLogits = logits; + boolean convertToFloat32 = + logits.asOutput().dataType() == TFloat16.DTYPE + || logits.asOutput().dataType() == TBfloat16.DTYPE; + if (convertToFloat32) { + preciseLogits = Cast.create(scope, logits, TFloat32.DTYPE); + } + Shape labelsStaticShape = labels.asOutput().shape(); + org.tensorflow.op.core.Shape labelsShape = + org.tensorflow.op.core.Shape.create(scope, labels); + Shape logitsShape = logits.asOutput().shape(); + Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1); + + boolean staticShapesFullyDefined = + !labelsStaticShape.hasUnknownDimension() && !logitsShortened.hasUnknownDimension(); + if (logitsShape.numDimensions() == 0) { + throw new IllegalArgumentException( + String.format("Logits cannot be scalars - received shape %s.", logitsShape)); + } + if (!logitsShape.hasUnknownDimension() + && !labelsStaticShape.hasUnknownDimension() + && labelsStaticShape.numDimensions() != logitsShape.numDimensions() - 1) { + throw new IllegalArgumentException( + String.format( + "Rank mismatch: Rank of labels (received %s) should equal rank of logits minus 1 (received %s).", + labelsStaticShape.toString(), logitsShape.toString())); + } + + if (staticShapesFullyDefined && !labelsStaticShape.equals(logitsShortened)) { + throw new IllegalArgumentException( + String.format( + "Shape mismatch: The shape of labels (received %s) " + + "should equal the shape of logits except for the last " + + "dimension (received %s).", + labelsStaticShape.toString(), logitsShape.toString())); + } + // Check if no reshapes are required. 
+ if (logitsShape.numDimensions() == 2) { + org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( + scope, preciseLogits, labels); + Operand loss = smax.loss(); + if (logits.asOutput().dataType() == TFloat16.DTYPE) { + loss = Cast.create(scope, loss, TFloat16.DTYPE); + } + return loss; + } + + List shapeChecks = new ArrayList<>(); + + if (!staticShapesFullyDefined) { + shapeChecks.add( + AssertThat.create( + scope, + Equal.create( + scope, + org.tensorflow.op.core.Shape.create(scope, labels), + Shapes.take( + scope, + org.tensorflow.op.core.Shape.create(scope, logits), + Constant.scalarOf(scope, -1))), + Collections.singletonList( + Constant.scalarOf( + scope, + "Shape mismatch: The shape of labels " + + "should equal the shape of logits except for the last " + + "dimension ")))); + } + + // Reshape logits to 2 dims, labels to 1 dim. + long numClassses = logitsShape.size(-1); + + preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses)); + labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1)); + scope.withControlDependencies(shapeChecks); + org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( + scope, preciseLogits, labels); + Operand cost = smax.loss(); + cost = Reshape.create(scope, cost, labelsShape); + if (logits.asOutput().dataType() == TFloat16.DTYPE) { + cost = Cast.create(scope, cost, TFloat16.DTYPE); + } + return cost; + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java index bac5fb96f87..3cc72101893 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java @@ -17,21 +17,22 @@ package org.tensorflow.types; -import java.util.function.Consumer; import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; -import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.buffer.BooleanDataBuffer; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; import org.tensorflow.types.family.TType; +import java.util.function.Consumer; + /** * Boolean tensor type. 
* diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java index 0f097a16ddb..6e2e7a7ba56 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java @@ -17,23 +17,24 @@ package org.tensorflow.types; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.util.function.Function; import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.internal.buffer.StringTensorBuffer; import org.tensorflow.internal.buffer.TensorBuffers; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffer; import org.tensorflow.ndarray.buffer.layout.DataLayout; import org.tensorflow.ndarray.buffer.layout.DataLayouts; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.impl.dense.DenseNdArray; import org.tensorflow.types.family.TType; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.function.Function; + /** * String type. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index a058649373a..a099eae53e8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -15,63 +15,98 @@ */ package org.tensorflow.framework.optimizers; -import java.util.Collections; -import java.util.List; -import java.util.Map; - import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.Tensor; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; -import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.List; + /** - * SGD plus momentum, either nesterov or traditional. - *

- * See the paper for details of - * nesterov momentum. + * Stochastic gradient descent plus momentum, either nesterov or traditional. + * + *
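The update that applyDense later in this class delegates to the raw ApplyMomentum kernel can be sketched in plain Java roughly as follows (the kernel semantics stated here are an assumption, not taken from this patch):

    // weight and accum are the variable and its momentum slot; lr and momentum are scalars.
    static double[] momentumStep(double weight, double accum, double grad,
                                 double lr, double momentum, boolean nesterov) {
      accum = accum * momentum + grad;               // update the momentum slot
      if (nesterov) {
        weight -= grad * lr + accum * momentum * lr; // Nesterov look-ahead variant
      } else {
        weight -= accum * lr;                        // classic momentum
      }
      return new double[] {weight, accum};
    }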

See the paper for details + * of nesterov momentum. */ public class Momentum extends Optimizer { - public static final String MOMENTUM = "momentum"; + public static final float LEARNING_RATE_DEFAULT = 0.01F; + public static final float MOMENTUM_DEFAULT = 0.0F; + public static final boolean NESTEROV_DEFAULT = false; - private float learningRate; - private Tensor learningRateTensor; - private final Placeholder learningRatePlaceholder; - private Map, Tensor> feedDict; + public static final String MOMENTUM = "momentum"; private final float momentum; private final boolean useNesterov; + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + */ + public Momentum(Graph graph) { + this(graph, LEARNING_RATE_DEFAULT, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate + */ + public Momentum(Graph graph, float learningRate) { + this(graph, learningRate, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. Default is 0. + */ + public Momentum(Graph graph, float learningRate, float momentum) { + this(graph, learningRate, momentum, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. Default is 0. + * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + */ public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { - super(graph); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + super(graph, learningRate); this.momentum = momentum; this.useNesterov = useNesterov; } - public Momentum(Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { - super(graph, name); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer + * @param learningRate the learning rate + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. Default is 0. + * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. 
+ */ + public Momentum( + Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { + super(graph, name, learningRate); this.momentum = momentum; this.useNesterov = useNesterov; } + /** {@inheritDoc} */ @Override protected void createSlots(List> variables) { for (Output v : variables) { @@ -79,64 +114,46 @@ protected void createSlots(List> variables) { } } + /** + * Creates a slot for the momentum variable + * + * @param v the variable + * @param the data type of the variable + */ private void createMomentumSlot(Output v) { - Operand initializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand initializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); createSlot(v.asOutput(), MOMENTUM, initializer); } + /** {@inheritDoc} */ @Override protected Op applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, MOMENTUM).get(); - return tf.train - .applyMomentum(variable, slot, tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), - gradient, - tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), - ApplyMomentum.useNesterov(useNesterov)); + return tf.train.applyMomentum( + variable, + slot, + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), + gradient, + tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), + ApplyMomentum.useNesterov(useNesterov)); } + /** {@inheritDoc} */ @Override public String toString() { - return "Momentum{" + - "learningRate=" + learningRate + - ", momentum=" + momentum + - ", useNesterov=" + useNesterov + - '}'; + return "Momentum{" + + "learningRate=" + + learningRate + + ", momentum=" + + momentum + + ", useNesterov=" + + useNesterov + + '}'; } + /** {@inheritDoc} */ @Override public String getOptimizerName() { return "Momentum"; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public final void setLearningRate(float learningRate) { - this.learningRate = learningRate; - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - } - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - } - - /** {@inheritDoc} */ - public Map, Tensor> getFeedDict() { - return this.feedDict; - } - - /** {@inheritDoc} */ - @Override - public void close() throws Exception { - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - this.learningRateTensor = null; - } - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java new file mode 100644 index 00000000000..d0228eb8b3a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -0,0 +1,295 @@ +package org.tensorflow.framework.optimizers; + +import org.tensorflow.DataType; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +import java.util.List; +import java.util.Optional; + +/** + * Nadam Optimizer that implements the NAdam algorithm. 
+ * + *

Much like Adam is essentially RMSprop with momentum, Nadam is Adam with Nesterov momentum. + * + * @see Dozat, 2015 + */ +public class Nadam extends Optimizer { + + private static final float DECAY_BASE = 0.96f; + private static final float DECAY = 0.004f; + public static final float LEARNING_RATE_DEFAULT = 0.001f; + public static final float EPSILON_DEFAULT = 1e-8f; + public static final float BETA_ONE_DEFAULT = 0.9f; + public static final float BETA_TWO_DEFAULT = 0.999f; + public static final String FIRST_MOMENT = "m"; + public static final String SECOND_MOMENT = "v"; + public static final String MOMENTUM = "momentum"; + + /** The exponential decay rate for the 1st moment estimates. */ + private final float betaOne; + + /** The exponential decay rate for the exponentially weighted infinity norm. */ + private final float betaTwo; + + /** A small constant for numerical stability. */ + private final float epsilon; + + private Constant epsilonConst; + private Constant betaOneConst; + private Constant betaTwoConst; + + private Variable betaOnePower; + private Variable betaTwoPower; + private Variable momentum; + + private long iterations = 0; + + // private Operand mT; + private Operand mT1; + + private Operand oneMinusBeta1; + private Operand oneMinusBeta2; + private Operand oneMinusMT; + private Operand oneMinusMScheduleNew; + private Operand oneMinusMScheduleNext; + private Operand vTPrimeDenominator; + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + */ + public Nadam(Graph graph) { + this(graph, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate, defaults to 0.001 + */ + public Nadam(Graph graph, float learningRate) { + this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param learningRate the learning rate, defaults to 0.001 + * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default + * is 0.999. + * @param epsilon A small constant for numerical stability. Default is 1e-8. + */ + public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { + super(graph, learningRate); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer, defaults to "Nadam" + * @param learningRate the learning rate, defaults to 0.001 + */ + public Nadam(Graph graph, String name, float learningRate) { + this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer, defaults to "Nadam" + * @param learningRate the learning rate, defaults to 0.001 + * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default + * is 0.999. + * @param epsilon A small constant for numerical stability. Default is 1e-8. 
+ */ + public Nadam( + Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { + super(graph, name, learningRate); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + + /** {@inheritDoc} */ + @Override + protected void createSlots(List> variables) { + for (Output v : variables) { + createNadamSlot(v.asOutput()); + } + betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); + Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); + ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); + + betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); + Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); + ((Graph) tf.scope().env()).addInitializer(betaTwoPowerInit); + + momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.DTYPE); + Assign momentumInit = tf.assign(momentum, tf.constant(1.0F)); + ((Graph) tf.scope().env()).addInitializer(momentumInit); + } + + /** + * Creates slots for first and second moments and momentum + * + * @param v the variable + * @param the data type or the Variable + */ + private void createNadamSlot(Output v) { + Operand firstMomentInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); + Operand secondMomentInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); + + Operand momentumInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); + createSlot(v.asOutput(), MOMENTUM, momentumInitializer); + } + + /** {@inheritDoc} */ + @Override + protected Optional prepare(String scopeName) { + Constant one = tf.constant(1.0F); + Constant point5 = tf.constant(0.5F); + + betaOneConst = tf.constant(betaOne); + betaTwoConst = tf.constant(betaTwo); + Constant localStepConst = tf.constant(this.iterations + 1); + Constant nextStepConst = tf.constant(this.iterations + 2); + Constant decayConst = tf.constant(DECAY); + Constant decayBaseConst = tf.constant(DECAY_BASE); + epsilonConst = tf.constant(this.epsilon); + + Operand mT = + tf.math.mul( + betaOneConst, + tf.math.sub( + one, + tf.math.mul( + point5, + tf.math.pow( + decayBaseConst, + tf.math.mul(decayConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE)))))); + + mT1 = + tf.math.mul( + betaOneConst, + tf.math.sub( + one, + tf.math.mul( + point5, + tf.math.pow( + decayBaseConst, + tf.math.mul(decayConst, tf.dtypes.cast(nextStepConst, TFloat32.DTYPE)))))); + + Operand mScheduleNew = tf.math.mul(momentum, mT); + + mScheduleNew = tf.assign(momentum, mScheduleNew, Assign.useLocking(true)); + Operand mScheduleNext = tf.math.mul(mScheduleNew, mT1); + + oneMinusBeta1 = tf.math.sub(one, betaOneConst); + oneMinusBeta2 = tf.math.sub(one, betaTwoConst); + oneMinusMT = tf.math.sub(one, mT); + oneMinusMScheduleNew = tf.math.sub(one, mScheduleNew); + oneMinusMScheduleNext = tf.math.sub(one, mScheduleNext); + vTPrimeDenominator = + tf.math.sub(one, tf.math.pow(betaTwoConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE))); + return Optional.empty(); + } + + /** {@inheritDoc} */ + @Override + protected Op applyDense(Output gradient, Output variable) { + DataType dType = gradient.dataType(); + Variable m = getSlot(variable, FIRST_MOMENT).get(); // first Moment + Variable v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment + + // gPrime 
= grad / coefficients['oneMinusMScheduleNew'] + Operand gPrime = tf.math.div(gradient, tf.dtypes.cast(oneMinusMScheduleNew, dType)); + // mT = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad) + Operand mT = + tf.math.add( + tf.math.mul(tf.dtypes.cast(betaOneConst, dType), m), + tf.math.mul(tf.dtypes.cast(oneMinusBeta1, dType), gradient)); + // mT = state_ops.assign(m, mT, use_locking=self._use_locking) + // update m + mT = tf.assign(m, mT, Assign.useLocking(true)); + + // mTPrime = mT / coefficients['oneMinusMScheduleNext'] + Operand mTPrime = tf.math.div(mT, tf.dtypes.cast(oneMinusMScheduleNext, dType)); + + // vT = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] * + // math_ops.square(grad)) + Operand vT = + tf.math.add( + tf.math.mul(tf.dtypes.cast(betaTwoConst, dType), v), + tf.math.mul(tf.dtypes.cast(oneMinusBeta2, dType), tf.math.square(gradient))); + // vT = state_ops.assign(v, vT, use_locking=self._use_locking) + // update v + vT = tf.assign(v, vT, Assign.useLocking(true)); + + // vTPrime = vT / coefficients['vTPrimeDenominator'] + Operand vTPrime = tf.math.div(vT, tf.dtypes.cast(vTPrimeDenominator, dType)); + + // m_t_bar = (coefficients['oneMinusMT'] * gPrime + coefficients['mT1'] * mTPrime) + Operand m_t_bar = + tf.math.add( + tf.math.mul(tf.dtypes.cast(oneMinusMT, dType), gPrime), + tf.math.mul(tf.dtypes.cast(mT1, dType), mTPrime)); + // varT = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(vTPrime) + + // coefficients['epsilon']) + Operand varT = + tf.math.sub( + variable, + tf.math.div( + tf.math.mul(tf.dtypes.cast(getLearningRateOperand(), dType), m_t_bar), + tf.math.add(tf.math.sqrt(vTPrime), tf.dtypes.cast(epsilonConst, dType)))); + + return tf.assign(variable, varT, Assign.useLocking(true)); + } + + /** + * Gathers up the update operations into a single op that can be used as a run target. + * + *
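The per-step momentum schedule built in prepare() above amounts to the following, shown in plain Java for a single step (illustrative values; the constants mirror DECAY_BASE and DECAY):

    double beta1 = 0.9, decay = 0.004, decayBase = 0.96;
    long step = 0; // value of the iterations counter before the update
    // mu_t and mu_{t+1}: momentum scaled by a slowly warming schedule
    double muT    = beta1 * (1.0 - 0.5 * Math.pow(decayBase, decay * (step + 1)));
    double muNext = beta1 * (1.0 - 0.5 * Math.pow(decayBase, decay * (step + 2)));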

Adds the betaOne, betaTwo and mu updates to the end of the updates list. + * + * @param updateOperations The update operations. + * @param name The name of the run target. + * @return A NoOp with a control dependency on each update operation. + */ + @Override + protected Op finish(List updateOperations, String name) { + iterations++; // increment the step; + updateOperations.add(tf.assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst))); + updateOperations.add(tf.assign(betaTwoPower, tf.math.mul(betaTwoPower, betaTwoConst))); + return super.finish(updateOperations, name); + } + + /** {@inheritDoc} */ + @Override + public String getOptimizerName() { + return "Nadam"; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index def464a86ca..8e0471dc0ba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -15,50 +15,47 @@ */ package org.tensorflow.framework.optimizers; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - import org.tensorflow.*; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.NoOp; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -/** - * Base class for gradient optimizers. - */ -public abstract class Optimizer implements AutoCloseable { +import java.util.*; +import java.util.stream.Collectors; + +/** Base class for gradient optimizers. */ +public abstract class Optimizer implements AutoCloseable { public static final String LEARNING_RATE = "learning_rate"; public static final String VARIABLE_V2 = "VariableV2"; - /** - * Global state variables - */ - //TODO make this be used. + public static final float LEARNING_RATE_DEFAULT = 0.001f; + + /** Global state variables */ + // TODO make this be used. protected final List> globals; - /** - * The Graph this optimizer is operating on. - */ + /** The Graph this optimizer is operating on. */ protected final Graph graph; - /** - * The ops builder for the graph. - */ + /** The ops builder for the graph. */ protected final Ops tf; - /** - * Top level map key is the variable name, lower level map key is the slot name. - */ + /** Top level map key is the variable name, lower level map key is the slot name. */ private final Map>> slots; + protected float learningRate; + protected Placeholder learningRatePlaceholder = null; + private Tensor learningRateTensor; + private Map, Tensor> feedMap = null; + /** * Builds an optimizer for the supplied graph. - *

- * Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. + * + *

Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. + * * @param graph The graph to optimize. */ protected Optimizer(Graph graph) { @@ -66,10 +63,28 @@ protected Optimizer(Graph graph) { this.tf = Ops.create(graph).withName(getOptimizerName()); this.slots = new HashMap<>(); this.globals = new ArrayList<>(); + setLearningRate(LEARNING_RATE_DEFAULT); } /** * Builds an optimizer for the supplied graph. + * + *

Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. + * + * @param graph The graph to optimize. + * @param learningRate the learning rate. + */ + protected Optimizer(Graph graph, float learningRate) { + this.graph = graph; + this.tf = Ops.create(graph).withName(getOptimizerName()); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + setLearningRate(learningRate); + } + + /** + * Builds an optimizer for the supplied graph. + * * @param graph The graph to optimize. * @param name The base name for the operations. */ @@ -78,6 +93,22 @@ protected Optimizer(Graph graph, String name) { this.tf = Ops.create(graph).withName(name); this.slots = new HashMap<>(); this.globals = new ArrayList<>(); + setLearningRate(LEARNING_RATE_DEFAULT); + } + + /** + * Builds an optimizer for the supplied graph. + * + * @param graph The graph to optimize. + * @param name The base name for the operations. + * @param learningRate the learning rate. + */ + protected Optimizer(Graph graph, String name, float learningRate) { + this.graph = graph; + this.tf = Ops.create(graph).withName(name); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + setLearningRate(learningRate); } public static String createName(Output variable, String slotName) { @@ -96,11 +127,14 @@ public Op minimize(Operand loss, String name) { public List> computeGradients(Operand loss) { List variables = new ArrayList<>(); - graph.operations().forEachRemaining((Operation op) -> { - if (op.type().equals(VARIABLE_V2)) { - variables.add(op); - } - }); + graph + .operations() + .forEachRemaining( + (Operation op) -> { + if (op.type().equals(VARIABLE_V2)) { + variables.add(op); + } + }); Output[] variableOutputArray = new Output[variables.size()]; for (int i = 0; i < variables.size(); i++) { @@ -123,8 +157,8 @@ public List> computeGradients(Operand loss) { } public Op applyGradients(List> gradsAndVars, String name) { - List> variables = gradsAndVars.stream().map(GradAndVar::getVariable) - .collect(Collectors.toList()); + List> variables = + gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); createSlots(variables); @@ -142,7 +176,7 @@ public Op applyGradients(List> gradsAndVars, String /** * Gets the slot associated with the specified variable and slot name. * - * @param var The variable to lookup. + * @param var The variable to lookup. * @param slotName The slot name. * @return The slot or {@link Optional#empty}. */ @@ -153,7 +187,7 @@ public Optional> getSlot(Output var, String slo /** * Gets the slot associated with the specified variable and slot name. * - * @param varName The variable to lookup. + * @param varName The variable to lookup. * @param slotName The slot name. * @return The slot or {@link Optional#empty}. */ @@ -163,7 +197,7 @@ private Optional> getSlot(String varName, String s Variable slot = variables.get(varName); if (slot != null) { @SuppressWarnings("unchecked") // This method should only be called when the type is known. - Optional> opt = Optional.of((Variable) slot); + Optional> opt = Optional.of((Variable) slot); return opt; } return Optional.empty(); @@ -175,20 +209,20 @@ private Optional> getSlot(String varName, String s * Creates a slot in the graph for the specified variable with the specified name. Adds the slot's * initializer to the graph's initializers, and the slot to the Optimizer's slot map. * - * @param variable The variable to create the slot for. - * @param slotName The name of the slot. 
+ * @param variable The variable to create the slot for. + * @param slotName The name of the slot. * @param initializer The initializer for the slot. - * @param The type of the variable. + * @param The type of the variable. */ - protected void createSlot(Output variable, String slotName, - Operand initializer) { - Variable slot = tf.withName(createName(variable, slotName)) - .variable(variable.shape(), variable.dataType()); + protected void createSlot( + Output variable, String slotName, Operand initializer) { + Variable slot = + tf.withName(createName(variable, slotName)).variable(variable.shape(), variable.dataType()); Assign slotInit = tf.assign(slot, initializer); graph.addInitializer(slotInit); String varName = variable.op().name(); - Map> variables = slots - .computeIfAbsent(slotName, (k) -> new HashMap<>()); + Map> variables = + slots.computeIfAbsent(slotName, (k) -> new HashMap<>()); variables.put(varName, slot); } @@ -206,8 +240,7 @@ protected Optional prepare(String scopeName) { * * @param variables The variables to create slots for. */ - protected void createSlots(List> variables) { - } + protected void createSlots(List> variables) {} private Op applyDense(GradAndVar gradVarPair) { return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable()); @@ -218,7 +251,7 @@ private Op applyDense(GradAndVar gradVarPair) { * * @param gradient The gradient to use. * @param variable The variable to update. - * @param The type of the variable. + * @param The type of the variable. * @return An operand which applies the desired optimizer update to the variable. */ protected abstract Op applyDense(Output gradient, Output variable); @@ -227,7 +260,7 @@ private Op applyDense(GradAndVar gradVarPair) { * Gathers up the update operations into a single op that can be used as a run target. * * @param updateOperations The update operations. - * @param name The name of the run target. + * @param name The name of the run target. * @return A NoOp with a control dependency on each update operation. */ protected Op finish(List updateOperations, String name) { @@ -238,44 +271,78 @@ protected Op finish(List updateOperations, String name) { } /** - * Name of the optimizer. + * Gets the Name of the optimizer. * * @return The optimizer name. 
*/ public abstract String getOptimizerName(); /** - * Set the learning rate + * Sets the learning rate + * * @param learningRate the learning rate */ - public abstract void setLearningRate(float learningRate); + public final void setLearningRate(float learningRate) { + if (this.learningRatePlaceholder == null) { + this.learningRatePlaceholder = + tf.withSubScope(LEARNING_RATE) + .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); + } + + if (this.learningRate != learningRate) { + if (this.learningRateTensor != null) this.learningRateTensor.close(); + this.learningRate = learningRate; + this.learningRateTensor = TFloat32.scalarOf(this.learningRate); + this.feedMap = Collections.singletonMap(this.learningRatePlaceholder, learningRateTensor); + } + } /** - * Get the learning rate + * Gets the learning rate + * * @return the learning rate */ - public abstract float getLearningRate(); + public float getLearningRate() { + return this.learningRate; + } /** - * Get the Feed Dictionary for the run methods to set the Placeholder values(s) + * Gets the learning rate Operand, used by subclasses in their graph operations * - * @return the current Feed Dictionary for the run methods + * @return the learning rate Operand */ - public abstract Map, Tensor> getFeedDict(); + protected Operand getLearningRateOperand() { + return this.learningRatePlaceholder; + } /** - * Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} + * Gets the Feed Map for the run methods to set the Placeholder value(s). Each entry in the Feed + * Map contains a PlaceHolder and a Tensor with the value + * + * @return the current Feed Map for the run methods, this may be null if an LearningRate as an + * Operand has been set. */ + public Map, Tensor> getFeedMap() { + return this.feedMap; + } + + public void close() { + // close the learningRate Tensor if it exists. + if (this.feedMap != null) { + this.feedMap.get(this.learningRatePlaceholder).close(); + } + } + + /** Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} */ public static class Options { protected String sharedName; - private Options() { - } + private Options() {} /** * @param sharedName If non-empty, this variable is named in the given bucket with this - * shared_name. Otherwise, the node name is used instead. + * shared_name. Otherwise, the node name is used instead. */ public Optimizer.Options sharedName(String sharedName) { this.sharedName = sharedName; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java new file mode 100644 index 00000000000..8d7f9620984 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizers.java @@ -0,0 +1,41 @@ +package org.tensorflow.framework.optimizers; + +import org.tensorflow.Graph; + +import java.util.function.Function; + +/** Enumerator used to create a new Optimizer with default parameters. 
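A hypothetical end-to-end sketch of how a training loop could combine setLearningRate with getFeedMap; the variable, loss, and decay schedule are invented for illustration, and the feed/target calls assume the standard Session.Runner API:

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.Session;
    import org.tensorflow.framework.optimizers.Momentum;
    import org.tensorflow.ndarray.Shape;
    import org.tensorflow.op.Op;
    import org.tensorflow.op.Ops;
    import org.tensorflow.op.core.Assign;
    import org.tensorflow.op.core.Variable;
    import org.tensorflow.types.TFloat32;

    public class LearningRateFeedExample {
      public static void main(String[] args) {
        try (Graph graph = new Graph()) {
          Ops tf = Ops.create(graph);
          Variable<TFloat32> weight = tf.variable(Shape.scalar(), TFloat32.DTYPE);
          Assign<TFloat32> init = tf.assign(weight, tf.constant(5.0f));
          Operand<TFloat32> loss = tf.math.square(weight); // toy loss: w^2

          Momentum optimizer = new Momentum(graph, 0.1f, 0.9f, false);
          Op train = optimizer.minimize(loss, "train");

          try (Session session = new Session(graph)) {
            session.runner().addTarget(init).run(); // initialize the variable
            for (int step = 0; step < 10; step++) {
              optimizer.setLearningRate(0.1f / (1 + step)); // change the rate between steps
              Session.Runner runner = session.runner().addTarget(train);
              if (optimizer.getFeedMap() != null) {
                optimizer.getFeedMap().forEach(runner::feed); // placeholder -> scalar tensor
              }
              runner.run();
            }
          }
          optimizer.close(); // releases the learning-rate tensor
        }
      }
    }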
*/ +public enum Optimizers { + ADADELTA(AdaDelta::new), + ADAGRAD(AdaGrad::new), + ADAGRAD_DA(AdaGradDA::new), + ADAM(Adam::new), + ADAMAX(Adamax::new), + FTRL(Ftrl::new), + NADAM(Nadam::new), + RMSPROP(RMSProp::new), + MOMENTUM(Momentum::new), + GRADIENT_DESCENT(GradientDescent::new); + + private final Function creator; + + /** + * Creates an Optimizers enum + * + * @param creator the lambda function that accepts a Graph argument used to create the default + * Optimizer + */ + Optimizers(Function creator) { + this.creator = creator; + } + + /** + * Creates an Optimizer with default settings. + * + * @param graph the TensorFlow Graph + * @return the Optimizer + */ + public Optimizer createOptimizer(Graph graph) { + return creator.apply(graph); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index 3d28c016de7..face906d682 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -15,79 +15,152 @@ */ package org.tensorflow.framework.optimizers; -import java.util.Collections; -import java.util.List; -import java.util.Map; - import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Output; -import org.tensorflow.Tensor; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; -import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; +import java.util.List; + /** * Optimizer that implements the RMSProp algorithm. - *
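A one-line usage sketch of the enum (the graph variable is assumed to already exist):

    // Builds an Adam optimizer with all-default hyperparameters for an existing graph.
    Optimizer optimizer = Optimizers.ADAM.createOptimizer(graph);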

- * See the lecture
- * notes that is inexplicably the canonical reference.
+ *
+ * The gist of RMSprop is to:
+ *
+ *   • Maintain a moving (discounted) average of the square of gradients
+ *   • Divide the gradient by the root of this average
+ *
+ * This implementation of RMSprop uses plain momentum, not Nesterov momentum.
+ *
+ * The centered version additionally maintains a moving average of the gradients, and uses
+ * that average to estimate the variance.
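+ *
+ * A minimal usage sketch ({@code graph}, {@code gradsAndVars} and a session helper that accepts
+ * a feed map are assumed to already exist):
+ *
+ * <pre>{@code
+ * RMSProp optimizer = new RMSProp(graph, 0.01f);
+ * Op update = optimizer.applyGradients(gradsAndVars, "RMSPropTrain");
+ * // the learning rate is fed through a placeholder, so pass the feed map on each run
+ * session.run(update, optimizer.getFeedMap());
+ * // the rate may be lowered between steps; the feed map is refreshed automatically
+ * optimizer.setLearningRate(0.001f);
+ * session.run(update, optimizer.getFeedMap());
+ * }</pre>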

    + * + * @see Hinton G, + * et al. 2012, lecture notes that is inexplicably the canonical reference. */ public class RMSProp extends Optimizer { + public static final float LEARNING_RATE_DEFAULT = 0.001f; + public static final float DECAY_DEFAULT = 0.9f; + public static final float MOMENTUM_DEFAULT = 0.0f; + public static final float EPSILON_DEFAULT = 1e-10f; + public static final boolean CENTERED_DEFAULT = false; public static final String RMS = "rms"; public static final String MG = "mg"; // mean gradient? public static final String MOMENTUM = "momentum"; - private float learningRate; - private Tensor learningRateTensor; - private final Placeholder learningRatePlaceholder; - private Map, Tensor> feedDict; private final float decay; private final float momentum; private final float epsilon; private final boolean centered; + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + */ + public RMSProp(Graph graph) { + this( + graph, + LEARNING_RATE_DEFAULT, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); + } + + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param learningRate the learning rate + */ public RMSProp(Graph graph, float learningRate) { - this(graph, learningRate, 0.9f, 0.0f, 1e-10f, false); + this(graph, learningRate, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT, CENTERED_DEFAULT); } - public RMSProp(Graph graph, float learningRate, float decay, float momentum, float epsilon, + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param learningRate the learning rate + * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. + * @param momentum the acceleration factor, default is 0. + * @param epsilon A small constant for numerical stability + * @param centered If true, gradients are normalized by the estimated variance of the + * gradient; if false>, by the uncentered second moment. Setting this to + * true> may help with training, but is slightly more expensive in terms of computation + * and memory. Defaults to false. + */ + public RMSProp( + Graph graph, + float learningRate, + float decay, + float momentum, + float epsilon, boolean centered) { - super(graph); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - + super(graph, learningRate); this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; this.centered = centered; } + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param learningRate the learning rate + */ public RMSProp(Graph graph, String name, float learningRate) { - this(graph, name, learningRate, 0.9f, 0.0f, 1e-10f, false); + this( + graph, + name, + learningRate, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); } - public RMSProp(Graph graph, String name, float learningRate, float decay, float momentum, float epsilon, + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param learningRate the learning rate + * @param decay Discounting factor for the history/coming gradient. 
Defaults to 0.9. + * @param momentum The acceleration factor, default is 0. + * @param epsilon A small constant for numerical stability + * @param centered If true, gradients are normalized by the estimated variance of the + * gradient; if false>, by the uncentered second moment. Setting this to + * true> may help with training, but is slightly more expensive in terms of computation + * and memory. Defaults to false. + */ + public RMSProp( + Graph graph, + String name, + float learningRate, + float decay, + float momentum, + float epsilon, boolean centered) { - super(graph, name); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE).placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); + super(graph, name, learningRate); this.decay = decay; this.momentum = momentum; this.epsilon = epsilon; this.centered = centered; } + /** {@inheritDoc} */ @Override protected void createSlots(List> variables) { for (Output v : variables) { @@ -95,85 +168,75 @@ protected void createSlots(List> variables) { } } + /** + * Creates the RMSProp Slots for Root Mean Squared (RMS), MOMENTUM, and Mean Gradient (MG) + * + * @param v the variable to install in the slot + * @param the datatype of the variable. + */ private void createRMSPropSlot(Output v) { - Operand rmsInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); + Operand rmsInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); createSlot(v.asOutput(), RMS, rmsInitializer); - Operand momentumInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand momentumInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = tf - .fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); + Operand mgInitializer = + tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); createSlot(v.asOutput(), MG, mgInitializer); } } + /** {@inheritDoc} */ @Override protected Op applyDense(Output gradient, Output variable) { Variable rmsSlot = getSlot(variable, RMS).get(); Variable momentumSlot = getSlot(variable, MOMENTUM).get(); if (centered) { Variable mgSlot = getSlot(variable, MG).get(); - return tf.train.applyCenteredRmsProp(variable, mgSlot, rmsSlot, momentumSlot, - tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), + return tf.train.applyCenteredRmsProp( + variable, + mgSlot, + rmsSlot, + momentumSlot, + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); } - return tf.train.applyRmsProp(variable, rmsSlot, momentumSlot, - tf.dtypes.cast(learningRatePlaceholder, gradient.dataType()), + return tf.train.applyRmsProp( + variable, + rmsSlot, + momentumSlot, + tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(decay), gradient.dataType()), tf.dtypes.cast(tf.constant(momentum), gradient.dataType()), tf.dtypes.cast(tf.constant(epsilon), gradient.dataType()), gradient); } + /** {@inheritDoc} */ @Override public String toString() { - return "RMSProp{" + - 
"learningRate=" + learningRate + - ", decay=" + decay + - ", momentum=" + momentum + - ", epsilon=" + epsilon + - ", centered=" + centered + - '}'; + return "RMSProp{" + + "learningRate=" + + learningRate + + ", decay=" + + decay + + ", momentum=" + + momentum + + ", epsilon=" + + epsilon + + ", centered=" + + centered + + '}'; } + /** {@inheritDoc} */ @Override public String getOptimizerName() { return "RMSProp"; } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public final void setLearningRate(float learningRate) { - this.learningRate = learningRate; - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - } - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - } - - /** {@inheritDoc} */ - public Map, Tensor> getFeedDict() { - return this.feedDict; - } - - /** {@inheritDoc} */ - @Override - public void close() throws Exception { - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - this.learningRateTensor = null; - } - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java new file mode 100644 index 00000000000..43f85fa0ff1 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.optimizers.schedules;; + +/** + * A LearningRateSchedule that uses a piecewise constant decay schedule. + *

+ *
+ * The function computes the piecewise constant when passed the current optimizer step. This can
+ * be useful for changing the learning rate value across different invocations of optimizer
+ * functions.
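+ *
+ * A minimal construction sketch for the example described below ({@code step} is the current
+ * training step; the boundary and value arrays are illustrative):
+ *
+ * <pre>{@code
+ * PiecewiseConstantDecay schedule =
+ *     new PiecewiseConstantDecay(new float[] {100000f, 110000f}, new float[] {1.0f, 0.5f, 0.1f});
+ * float learningRate = schedule.call(step);
+ * }</pre>
+ *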

    Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 + for the next 10000 steps, and 0.1 for any additional steps. + */ +public class PiecewiseConstantDecay implements LearningRateSchedule { + private float[] boundaries; + private float[] values; + + private int lastIndex = 0; + + /** + * Create an PiecewiseConstantDecay + * + * @param boundaries An array of with strictly increasing entries + * @param values An array that specifies the + values for the intervals defined by boundaries. It should have one + more element than boundaries. + * @throws java.lang.IllegalArgumentException if the the length of values does not have 1 more element than boundaries. + */ + public PiecewiseConstantDecay(float[] boundaries, float[] values) { + if(boundaries.length != values.length - 1) { + throw new IllegalArgumentException("The length of boundaries should be 1 less than the length of values"); + } + this.boundaries = boundaries; + this.values = values; + } + + + @Override + public float call(int step) { + if(lastIndex < boundaries.length && step > boundaries[lastIndex]) + lastIndex++; + return values[lastIndex]; + } + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java new file mode 100644 index 00000000000..0988577c38f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.optimizers.schedules; + +/** + * A LearningRateSchedule that uses a polynomial decay schedule. + * + *

+ *
+ * It is commonly observed that a monotonically decreasing learning rate, whose degree of change
+ * is carefully chosen, results in a better performing model. This schedule applies a polynomial
+ * decay function to an optimizer step, given a provided `initial_learning_rate`, to reach an
+ * `end_learning_rate` in the given `decay_steps`.
+ *
+ * The schedule is a 1-arg callable that produces a decayed learning rate when passed the current
+ * optimizer step. This can be useful for changing the learning rate value across different
+ * invocations of optimizer functions. It is computed as:
+ *
+ *     step = min(step, decaySteps)
+ *     ((initialLearningRate - endLearningRate) * (1 - step / decaySteps) ^ power) + endLearningRate
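+ *
+ * For instance, with the illustrative values {@code initialLearningRate = 0.1f},
+ * {@code endLearningRate = 0.01f}, {@code decaySteps = 100} and the default linear power, step 50
+ * yields {@code (0.1 - 0.01) * (1 - 50 / 100) + 0.01 = 0.055}:
+ *
+ * <pre>{@code
+ * PolynomialDecay schedule = new PolynomialDecay(0.1f, 100, 0.01f);
+ * float learningRate = schedule.call(50); // approximately 0.055f
+ * }</pre>
+ *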

    If `cycle` is True then a multiple of `decay_steps` is used, the first one that is bigger than + * `step`. + */ +public class PolynomialDecay implements LearningRateSchedule { + private static final float END_LEARNING_RATE_DEFAULT = 0.0001f; + public static final float POWER_DEFAULT = 1.0f; + public static final boolean CYCLE_DEFAULT = false; + + protected final float initialLearningRate; + protected final float decaySteps; + protected final float endLearningRate; + protected final float power; + protected final boolean cycle; + + /** + * Create a PolynomialDecay + * + * @param initialLearningRate The initial learning rate. + * @param decaySteps How often to apply decay. + */ + public PolynomialDecay(float initialLearningRate, int decaySteps) { + this(initialLearningRate, decaySteps, END_LEARNING_RATE_DEFAULT, POWER_DEFAULT, CYCLE_DEFAULT); + } + + /** + * Create a PolynomialDecay + * + * @param initialLearningRate The initial learning rate. + * @param decaySteps How often to apply decay. + * @param cycle Whether or not it should cycle beyond decay_steps. Default is false. + */ + public PolynomialDecay(float initialLearningRate, int decaySteps, boolean cycle) { + this(initialLearningRate, decaySteps, END_LEARNING_RATE_DEFAULT, POWER_DEFAULT, cycle); + } + + /** + * Create a PolynomialDecay + * + * @param initialLearningRate The initial learning rate. + * @param decaySteps How often to apply decay. + * @param endLearningRate The end learning rate. Default is 0.0001. + */ + public PolynomialDecay(float initialLearningRate, int decaySteps, float endLearningRate) { + this(initialLearningRate, decaySteps, endLearningRate, POWER_DEFAULT, CYCLE_DEFAULT); + } + + /** + * Create a PolynomialDecay + * + * @param initialLearningRate The initial learning rate. + * @param decaySteps How often to apply decay. + * @param endLearningRate The end learning rate. Default is 0.0001. + * @param power The power of the polynomial. Defaults to linear, 1.0. + * @param cycle Whether or not it should cycle beyond decay_steps. Default is false. + */ + public PolynomialDecay( + float initialLearningRate, + int decaySteps, + float endLearningRate, + float power, + boolean cycle) { + this.initialLearningRate = initialLearningRate; + this.decaySteps = decaySteps; + this.endLearningRate = endLearningRate; + this.power = power; + this.cycle = cycle; + } + + @Override + public float call(int step) { + + float lDecaySteps = decaySteps; + float lStep = step; + if (cycle) { + float multipler = step == 0 ? 1.0f : (float) Math.ceil(step / decaySteps); + lDecaySteps = decaySteps * multipler; + } else { + lStep = Math.min(lStep, lDecaySteps); + } + + float p = lStep / lDecaySteps; + + float f = (this.initialLearningRate - this.endLearningRate) * (float) Math.pow(1.0f - p, power); + return f + endLearningRate; + } +} diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java similarity index 58% rename from tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index 1cf20f1b0d2..ce5ad379629 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/SGDTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -12,11 +12,11 @@ See the License for the specific language governing permissions and limitations under the License. 
=======================================================================*/ -package org.tensorflow.keras.optimizers; +package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.keras.utils.TestSession; +import org.tensorflow.Graph; +import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -24,25 +24,20 @@ import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.SGD.*; /** Test cases for SGD Optimizer */ -public class SGDTest { +public class MomentumTest { - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - int index; - - public SGDTest() {} + public MomentumTest() {} @BeforeAll public static void setUpClass() {} @@ -56,29 +51,13 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of create method, of class SGD. */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map config = new HashMap<>(); - config.put(NAME_KEY, "Ftrl"); - config.put(LEARNING_RATE_KEY, 2.0F); - config.put(MOMENTUM_KEY, MOMENTUM_DEFAULT); - config.put(NESTEROV_KEY, NESTEROV_DEFAULT); - SGD expResult = new SGD(tf, 2.0F); - SGD result = SGD.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } - } - /** Test of getOptimizerName method, of class SGD. 
*/ @Test public void testGetOptimizerName() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - SGD instance = new SGD(tf); - String expResult = "SGD"; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Graph graph = session.getGraph(); + Momentum instance = new Momentum(graph); + String expResult = "Momentum"; String result = instance.getOptimizerName(); assertEquals(expResult, result); } @@ -86,49 +65,47 @@ public void testGetOptimizerName() { @Test public void testBasic() { - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float learningRate = 3.0F; - float epsilon = 1e-6F; - float epsilon1 = 1e-2F; - - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + Graph graph = session.getGraph(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); - Constant grads0 = tf.constant(grads0_init); - Constant grads1 = tf.constant(grads1_init); + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - SGD instance = new SGD(tf, learningRate); + Momentum instance = new Momentum(graph, learningRate); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); /* initialize the local variables */ session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); - session.run(update, instance.getFeedDict()); // 1 step + session.run(update, instance.getFeedMap()); // 1 step float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; @@ -139,37 +116,34 @@ public void testBasic() { @Test public void testMomentum() { - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float learningRate = 2.0F; float momentum = 0.9F; - float epsilon = 1e-6F; - float epsilon1 = 1e-2F; - - try (TestSession session = 
TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate, momentum)) { Ops tf = session.getTF(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); - Constant grads0 = tf.constant(grads0_init); - Constant grads1 = tf.constant(grads1_init); + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - SGD instance = new SGD(tf, learningRate, momentum); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); Variable momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); @@ -181,14 +155,14 @@ public void testMomentum() { session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); - session.run(update, instance.getFeedDict()); // 1 step + session.run(update, instance.getFeedMap()); // 1 step float[] expectedMomentum0 = {0.1F, 0.1F}; float[] expectedMomentum1 = {0.01F, 0.01F}; @@ -200,57 +174,55 @@ public void testMomentum() { session.evaluate(expectedVar0, var0); session.evaluate(expectedVar1, var1); - session.run(update, instance.getFeedDict()); // step 2 + session.run(update, instance.getFeedMap()); // step 2 - float[] expectedMomentum0_2 = {(0.9f * 0.1f + 0.1f), (0.9f * 0.1f + 0.1f)}; - float[] expectedMomentum1_2 = {(0.9f * 0.01f + 0.01f), (0.9f * 0.01f + 0.01f)}; - session.evaluate(expectedMomentum0_2, momentumSlot0); - session.evaluate(expectedMomentum1_2, momentumSlot1); + float[] expectedMomentum02 = {(0.9f * 0.1f + 0.1f), (0.9f * 0.1f + 0.1f)}; + float[] expectedMomentum12 = {(0.9f * 0.01f + 0.01f), (0.9f * 0.01f + 0.01f)}; + session.evaluate(expectedMomentum02, momentumSlot0); + session.evaluate(expectedMomentum12, momentumSlot1); - float[] expectedVar0_2 = { + float[] expectedVar02 = { 1.0F - (0.1F * 2.0F) - ((0.9F * 0.1F + 0.1F) * 2.0F), 2.0F - (0.1F * 2.0F) - ((0.9F * 0.1F + 0.1F) * 2.0F) }; - float[] expectedVar1_2 = { + float[] expectedVar12 = { 2.98F - ((0.9F * 0.01F + 0.01F) * 2.0F), 3.98F - ((0.9F * 0.01F + 0.01F) * 2.0F) }; - session.evaluate(expectedVar0_2, var0); - session.evaluate(expectedVar1_2, var1); + session.evaluate(expectedVar02, var0); + session.evaluate(expectedVar12, var1); } } @Test public void testWithLearningRateDecay() { int numSteps = 2; - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 
0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float learningRate = 3.0F; - float epsilon = 1e-6F; - float epsilon1 = 1e-2F; - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); - Constant grads0 = tf.constant(grads0_init); - Constant grads1 = tf.constant(grads1_init); + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - SGD instance = new SGD(tf, learningRate); - Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + Op update = instance.applyGradients(gradsAndVars, "MomentumTest"); Variable momentumSlot0 = instance.getSlot(var0.asOutput(), MOMENTUM).get(); assertEquals(momentumSlot0.asOutput().shape(), var0.asOutput().shape()); @@ -261,12 +233,12 @@ public void testWithLearningRateDecay() { session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + // initialize the accumulators session.run(tf.init()); - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + // make sure the variables were initialized properly + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); float[][] expectedVar0 = { {0.7F, 1.7F}, @@ -283,7 +255,9 @@ public void testWithLearningRateDecay() { {2.966667F, 3.966667F} }; for (int step = 0; step < numSteps; step++) { - session.run(update, instance.getFeedDict()); + assertEquals(learningRate, instance.getLearningRate(), 1e-6); + session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); session.evaluate(expectedVar0[step], var0); session.evaluate(expectedVar1[step], var1); learningRate *= 0.1; diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java similarity index 50% rename from tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index 32d90ea91ed..fcdd1e3ef7c 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -12,13 +12,12 @@ See the License for the 
specific language governing permissions and limitations under the License. =======================================================================*/ -package org.tensorflow.keras.optimizers; +package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; import org.tensorflow.Tensor; -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.keras.utils.ND; -import org.tensorflow.keras.utils.TestSession; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -28,26 +27,21 @@ import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.keras.optimizers.Adamax.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.Nadam.*; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; /** Test cases for Nadam Optimizer */ public class NadamTest { - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private static final int VAR = 0; private static final int M = 1; private static final int V = 2; - int index = 0; float momentum = 1; public NadamTest() {} @@ -64,29 +58,12 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of create method, of class Nadam. */ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map config = new HashMap<>(); - config.put(NAME_KEY, "AdaDelta"); - config.put(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - config.put(BETA_ONE_KEY, BETA_ONE_DEFAULT); - config.put(BETA_TWO_KEY, BETA_TWO_DEFAULT); - config.put(EPSILON_KEY, EPSILON_DEFAULT); - AdaDelta expResult = new AdaDelta(tf); - AdaDelta result = AdaDelta.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } - } - /** Test of getOptimizerName method, of class Nadam. 
*/ @Test public void testGetOptimizerName() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Nadam instance = new Nadam(tf); + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph())) { + String expResult = "Nadam"; String result = instance.getOptimizerName(); assertEquals(expResult, result); @@ -99,10 +76,10 @@ public void testBasic() { int numSteps = 3; - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float[] zeros = {0.0F, 0.0F}; float[] ones = {1.0F, 1.0F}; @@ -111,63 +88,64 @@ public void testBasic() { FloatNdArray m1 = NdArrays.vectorOf(zeros); FloatNdArray v1 = NdArrays.vectorOf(zeros); FloatNdArray mcache = NdArrays.vectorOf(ones); - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); - float epsilon = 1e-6f; float epsilon1 = 1e-3F; - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph())) { Ops tf = session.getTF(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); - Constant grads0 = tf.constant(grads0_init); - Constant grads1 = tf.constant(grads1_init); + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); - Nadam instance = new Nadam(tf); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); Op update = instance.applyGradients(gradsAndVars, "AdamTest"); /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") Variable[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") Variable[] secondMomentSlots = new Variable[2]; - firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); - secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[0].asOutput().shape(), 
var0.asOutput().shape()); - firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); - secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); /* initialize the local variables */ session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); session.setEpsilon(epsilon1); - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); try (Tensor result = session @@ -177,19 +155,13 @@ public void testBasic() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> { - assertEquals(1F, f.getFloat(), epsilon1); - }); + result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; for (int step = 0; step < numSteps; step++) { - session.run(update, instance.getFeedDict()); + session.run(update, instance.getFeedMap()); float mut = Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); @@ -203,22 +175,16 @@ public void testBasic() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> { - assertEquals(momentum, f.getFloat(), epsilon1); - }); + result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); - FloatNdArray[] resultsNP = nadam_update_numpy(var0_np, grads0_np, step, m0, v0, mcache); - var0_np = resultsNP[VAR]; + FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache); + var0Np = resultsNP[VAR]; m0 = resultsNP[M]; v0 = resultsNP[V]; - resultsNP = nadam_update_numpy(var1_np, grads1_np, step, m1, v1, mcache); - var1_np = resultsNP[VAR]; + resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache); + var1Np = resultsNP[VAR]; m1 = resultsNP[M]; v1 = resultsNP[V]; @@ -231,8 +197,8 @@ public void testBasic() { session.evaluate(v1, secondMomentSlots[1]); // evaluate var0 and var1 - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); } } } @@ -241,10 +207,10 @@ public void testBasic() { public void testWithLearningRateDecay() { int numSteps = 3; - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.1F}; - float[] grads1_init = {0.01F, 0.01F}; + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; float[] zeros = {0.0F, 0.0F}; float[] ones = {1.0F, 1.0F}; @@ -253,117 +219,115 @@ public void testWithLearningRateDecay() { FloatNdArray m1 = NdArrays.vectorOf(zeros); FloatNdArray v1 = NdArrays.vectorOf(zeros); FloatNdArray mcache = NdArrays.vectorOf(ones); - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + 
FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); - float epsilon = 1e-6f; float epsilon1 = 1e-3F; float learningRate = 0.001F; - try (TestSession session = TestSession.createTestSession(tf_mode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Nadam instance = new Nadam(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); - Constant grads0 = tf.constant(grads0_init); - Constant grads1 = tf.constant(grads1_init); - Nadam instance = new Nadam(tf, learningRate); /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); + List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); Op update = instance.applyGradients(gradsAndVars, "AdamTest"); /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") Variable[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") Variable[] secondMomentSlots = new Variable[2]; - firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); - secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); - firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); - secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); /* initialize the local variables */ session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + // initialize the accumulators session.run(tf.init()); session.setEpsilon(epsilon1); - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); try (Tensor result = - session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0) - .expect(TFloat32.DTYPE)) { + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { result - .data() - .scalars() - .forEach( - f -> { - assertEquals(1F, f.getFloat(), epsilon1); - }); + .data() + .scalars() + .forEach( + f -> assertEquals(1F, 
f.getFloat(), epsilon1)); } momentum = 1F; for (int step = 0; step < numSteps; step++) { - - session.run(update, instance.getFeedDict()); + assertEquals(learningRate, instance.getLearningRate(), 1e-6f); + session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); float mut = - Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); momentum = momentum * mut; try (Tensor result = - session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0) - .expect(TFloat32.DTYPE)) { + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { result - .data() - .scalars() - .forEach( - f -> { - assertEquals(momentum, f.getFloat(), epsilon1); - }); + .data() + .scalars() + .forEach( + f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); - FloatNdArray[] resultsNP = - nadam_update_numpy(var0_np, grads0_np, step, m0, v0, mcache, learningRate); - var0_np = resultsNP[VAR]; + FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache, learningRate); + var0Np = resultsNP[VAR]; m0 = resultsNP[M]; v0 = resultsNP[V]; - resultsNP = nadam_update_numpy(var1_np, grads1_np, step, m1, v1, mcache, learningRate); - var1_np = resultsNP[VAR]; + resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache, learningRate); + var1Np = resultsNP[VAR]; m1 = resultsNP[M]; v1 = resultsNP[V]; @@ -376,8 +340,8 @@ public void testWithLearningRateDecay() { session.evaluate(v1, secondMomentSlots[1]); // evaluate var0 and var1 - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); learningRate *= 0.9; instance.setLearningRate(learningRate); @@ -385,50 +349,45 @@ public void testWithLearningRateDecay() { } } - private FloatNdArray update_m_cache(FloatNdArray mcache, int t) { - float mu_t = 0.9F * (1.0F - 0.5F * (float) Math.pow(0.96, (0.004 * (t + 1)))); - return ND.mul(mu_t, mcache); - } - private FloatNdArray[] nadam_update_numpy( - FloatNdArray var_np, - FloatNdArray grads_np, - int t, - FloatNdArray m, - FloatNdArray v, - FloatNdArray m_cache) { - return nadam_update_numpy(var_np, grads_np, t, m, v, m_cache, 0.001F); + private FloatNdArray[] nadamUpdateNdArray( + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache) { + return nadamUpdateNdArray(varNp, gradsNp, t, m, v, mCache, 0.001F); } - - private FloatNdArray[] nadam_update_numpy( - FloatNdArray var_np, - FloatNdArray grads_np, - int t, - FloatNdArray m, - FloatNdArray v, - FloatNdArray m_cache, - float alpha) { + private FloatNdArray[] nadamUpdateNdArray( + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache, + float alpha) { float beta1 = 0.9F; float beta2 = 0.999F; float epsilon = 1e-8F; - float mu_t = beta1 * (1F - 0.5F * (float) Math.pow(0.96, 0.004 * (t + 1))); - float mu_t_1 = beta1 * (1F - 0.5F * (float) Math.pow(0.96, (0.004 * (t + 2)))); - FloatNdArray m_cache_t_1 = ND.mul(m_cache, mu_t_1); - FloatNdArray g_prime_t = ND.div(grads_np, ND.sub(1.0F, m_cache)); - FloatNdArray m_t = ND.add(ND.mul(beta1, m), ND.mul((1 - beta1), grads_np)); - FloatNdArray v_t = ND.add(ND.mul(beta2, v), ND.mul((1 - beta2), ND.square(grads_np))); - - 
FloatNdArray m_prime_t = ND.div(m_t, ND.sub(1.F, m_cache_t_1)); - FloatNdArray v_prime_t = ND.div(v_t, 1.F - (float) Math.pow(beta2, t + 1)); - FloatNdArray m_bar_t = ND.add(ND.mul((1 - mu_t), g_prime_t), ND.mul(mu_t_1, m_prime_t)); - FloatNdArray param_t = - ND.sub(var_np, ND.div(ND.mul(alpha, m_bar_t), ND.add(ND.sqrt(v_prime_t), epsilon))); + float muT = beta1 * (1F - 0.5F * (float) Math.pow(0.96, 0.004 * (t + 1))); + float muT1 = beta1 * (1F - 0.5F * (float) Math.pow(0.96, (0.004 * (t + 2)))); + FloatNdArray mCacheT1 = ND.mul(mCache, muT1); + FloatNdArray gPrimeT = ND.div(gradsNp, ND.sub(1.0F, mCache)); + FloatNdArray mT = ND.add(ND.mul(beta1, m), ND.mul((1 - beta1), gradsNp)); + FloatNdArray vT = ND.add(ND.mul(beta2, v), ND.mul((1 - beta2), ND.square(gradsNp))); + + FloatNdArray mPrimeT = ND.div(mT, ND.sub(1.F, mCacheT1)); + FloatNdArray vPrimeT = ND.div(vT, 1.F - (float) Math.pow(beta2, t + 1)); + FloatNdArray mBarT = ND.add(ND.mul((1 - muT), gPrimeT), ND.mul(muT1, mPrimeT)); + FloatNdArray paramT = + ND.sub(varNp, ND.div(ND.mul(alpha, mBarT), ND.add(ND.sqrt(vPrimeT), epsilon))); FloatNdArray[] results = new FloatNdArray[3]; - results[VAR] = param_t; - results[M] = m_t; - results[V] = v_t; + results[VAR] = paramT; + results[M] = mT; + results[V] = vT; return results; } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java new file mode 100644 index 00000000000..a0bf027abab --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/OptimizersTest.java @@ -0,0 +1,134 @@ +package org.tensorflow.framework.optimizers; + +import org.junit.jupiter.api.*; +import org.tensorflow.framework.utils.TestSession; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class OptimizersTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + public OptimizersTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + /** Test ADADELTA enum */ + @Test + public void testADADELTA() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADADELTA.createOptimizer(session.getGraph())) { + String expResult = "Adadelta"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test ADAGRAD enum */ + @Test + public void testADAGRAD() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAGRAD.createOptimizer(session.getGraph())) { + String expResult = "Adagrad"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test ADAGRAD_DA enum */ + @Test + public void testADAGRAD_DA() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAGRAD_DA.createOptimizer(session.getGraph())) { + String expResult = "adagrad-da"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test ADAM enum */ + @Test + public void testADAM() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAM.createOptimizer(session.getGraph())) { + String expResult = "Adam"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + 
/** Test ADAMAX enum */ + @Test + public void testADAMAX() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.ADAMAX.createOptimizer(session.getGraph())) { + String expResult = "Adamax"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test FTRL enum */ + @Test + public void testFTRL() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.FTRL.createOptimizer(session.getGraph())) { + String expResult = "Ftrl"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test NADAM enum */ + @Test + public void testNADAM() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.NADAM.createOptimizer(session.getGraph())) { + String expResult = "Nadam"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test RMSPROP enum */ + @Test + public void testRMSPROP() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.RMSPROP.createOptimizer(session.getGraph())) { + String expResult = "RMSProp"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test MOMENTUM enum */ + @Test + public void testMOMENTUM() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.MOMENTUM.createOptimizer(session.getGraph())) { + String expResult = "Momentum"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } + + /** Test GRADIENT_DESCENT enum */ + @Test + public void testGRADIENT_DESCENT() { + try (TestSession session = TestSession.createTestSession(tfMode); + Optimizer instance = Optimizers.GRADIENT_DESCENT.createOptimizer(session.getGraph())) { + String expResult = "GradientDescent"; + String result = instance.getOptimizerName(); + assertEquals(expResult, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java new file mode 100644 index 00000000000..6d489951c77 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -0,0 +1,450 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.optimizers; + +import org.junit.jupiter.api.*; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.tensorflow.framework.optimizers.RMSProp.*; + +/** Test cases for RMSProp Optimizer */ +public class RMSPropTest { + final int VAR_T = 0; + final int MG_T = 1; + final int RMS_T = 2; + final int MOM_T = 3; + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + Object[][] testParamValues = { + // learningRate, rho (decay), momentum, epsilon, centered + {0.05F, 0.9F, 0.0F, 1e-3F, true}, + {0.05F, 0.9F, 0.0F, 1e-3F, false}, + {0.1F, 0.9F, 0.0F, 1e-3F, true}, + {0.01F, 0.9F, 0.0F, 1e-5F, true}, + {0.01F, 0.9F, 0.9F, 1e-5F, true} + }; + + public RMSPropTest() {} + + @BeforeAll + public static void setUpClass() {} + + @AfterAll + public static void tearDownClass() {} + + @BeforeEach + public void setUp() {} + + @AfterEach + public void tearDown() {} + + @Test + public void testDense() { + + int numSteps = 3; + + for (Object[] testParamValue : testParamValues) { + // learningRate, rho (decay), momentum, epsilon, centered + float learningRate = (float) (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + try (TestSession session = TestSession.createTestSession(tfMode); + RMSProp instance = + new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) { + Ops tf = session.getTF(); + + session.setEpsilon(1e-2f); + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.2F}; + + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make 
sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + Variable mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; + Variable mom0 = + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; + Variable mom1 = + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ? instance.getSlot(var1.asOutput(), RMS).get() + : null; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0Np = NdArrays.vectorOf(zeros); + FloatNdArray mg1Np = NdArrays.vectorOf(zeros); + FloatNdArray rms0Np = NdArrays.vectorOf(ones); + FloatNdArray rms1Np = NdArrays.vectorOf(ones); + FloatNdArray mom0Np = NdArrays.vectorOf(zeros); + FloatNdArray mom1Np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedMap()); + FloatNdArray[] result0 = + calc( + var0Np, + grads0Np, + mg0Np, + rms0Np, + mom0Np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0Np = result0[VAR_T]; + mg0Np = result0[MG_T]; + rms0Np = result0[RMS_T]; + mom0Np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1Np, + grads1Np, + mg1Np, + rms1Np, + mom1Np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1Np = result1[VAR_T]; + mg1Np = result1[MG_T]; + rms1Np = result1[RMS_T]; + mom1Np = result1[MOM_T]; + + if (centered) { + if (mg0 != null) session.evaluate(mg0Np, mg0); + if (mg1 != null) session.evaluate(mg1Np, mg1); + } + + if (mom0 != null) session.evaluate(mom0Np, mom0); + if (mom1 != null) session.evaluate(mom1Np, mom1); + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + if (rms0 != null) session.evaluate(rms0Np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1Np, rms1); + else fail("rms1 is null"); + + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + + @Test + public void testWithLearningRateDecay() { + int numSteps = 3; + + for (Object[] testParamValue : testParamValues) { + float learningRate = (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + + try (TestSession session = TestSession.createTestSession(tfMode); + RMSProp instance = + new RMSProp(session.getGraph(), learningRate, decay, momentum, epsilon, centered)) { + Ops tf = session.getTF(); + session.setEpsilon(1e-2f); + float[] var0_init = {1.0F, 2.0F}; + float[] var1_init = {3.0F, 4.0F}; + float[] grads0_init = {0.1F, 0.2F}; + float[] grads1_init = {0.01F, 0.2F}; + + FloatNdArray var0_np = NdArrays.vectorOf(var0_init); + FloatNdArray var1_np = NdArrays.vectorOf(var1_init); + FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); + FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); + + Shape shape0 = Shape.of(var0_init.length); + 
Shape shape1 = Shape.of(var1_init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); + + Constant grads0 = tf.constant(grads0_init); + Constant grads1 = tf.constant(grads1_init); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + // initialize the accumulators + session.run(tf.init()); + + // make sure the variables were initialized properly + session.evaluate(var0_init, var0); + session.evaluate(var1_init, var1); + + Variable mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; + Variable mom0 = + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; + Variable mom1 = + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ? 
instance.getSlot(var1.asOutput(), RMS).get() + : null; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0_np = NdArrays.vectorOf(zeros); + FloatNdArray mg1_np = NdArrays.vectorOf(zeros); + FloatNdArray rms0_np = NdArrays.vectorOf(ones); + FloatNdArray rms1_np = NdArrays.vectorOf(ones); + FloatNdArray mom0_np = NdArrays.vectorOf(zeros); + FloatNdArray mom1_np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + assertEquals(learningRate, instance.getLearningRate(), epsilon); + session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.run(update, instance.getFeedMap()); + FloatNdArray[] result0 = + calc( + var0_np, + grads0_np, + mg0_np, + rms0_np, + mom0_np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0_np = result0[VAR_T]; + mg0_np = result0[MG_T]; + rms0_np = result0[RMS_T]; + mom0_np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1_np, + grads1_np, + mg1_np, + rms1_np, + mom1_np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1_np = result1[VAR_T]; + mg1_np = result1[MG_T]; + rms1_np = result1[RMS_T]; + mom1_np = result1[MOM_T]; + + if (centered) { + if (mg0 != null) session.evaluate(mg0_np, mg0); + else fail("mg0 is null"); + if (mg1 != null) session.evaluate(mg1_np, mg1); + else fail("mg1 is null"); + } + if (momentum > 0.F) { + if (mom0 != null) session.evaluate(mom0_np, mom0); + if (mom1 != null) session.evaluate(mom1_np, mom1); + } + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + if (rms0 != null) session.evaluate(rms0_np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1_np, rms1); + else fail("rms1 is null"); + + session.evaluate(var0_np, var0); + session.evaluate(var1_np, var1); + + learningRate *= 0.9F; + instance.setLearningRate(learningRate); + } + } + } + } + + FloatNdArray[] calc( + FloatNdArray varNp, + FloatNdArray gradNp, + FloatNdArray mgNp, + FloatNdArray rmsNp, + FloatNdArray mom, + float lr, + float decay, + float momentum, + float epsilon, + boolean centered) { + + FloatNdArray[] result = new FloatNdArray[4]; // varT, mgT, rmsT, momT + result[RMS_T] = calcRMS(rmsNp, gradNp, decay); // RMS + + FloatNdArray denomT; + if (centered) { + result[MG_T] = calcMG(mgNp, gradNp, decay); + // rmsT - mgT * mgT + denomT = ND.sub(result[RMS_T], ND.square(result[MG_T])); + } else { + result[MG_T] = mgNp; + denomT = rmsNp; + } + if (momentum > 0.F) { + // momentum * mom + lr * g / (np.sqrt(denomT + epsilon)) + result[MOM_T] = calcMom(momentum, mom, lr, gradNp, denomT, epsilon); + // varT = var - momT + result[VAR_T] = ND.sub(varNp, result[MOM_T]); + } else { + result[MOM_T] = mom; + result[VAR_T] = calcVar(varNp, gradNp, lr, denomT, epsilon); + } + + return result; + } + + private FloatNdArray calcRMS(FloatNdArray rmsNp, FloatNdArray gradNp, float decay) { + // rms * rho + (1 - rho) * g * g + FloatNdArray rmsRho = ND.mul(rmsNp, decay); + FloatNdArray squareG = ND.square(gradNp); + float oneRHO = 1.0F - decay; + FloatNdArray decayG2 = ND.mul(oneRHO, squareG); + return ND.add(rmsRho, decayG2); + } + + private FloatNdArray calcMG(FloatNdArray mgNp, FloatNdArray gradNp, float decay) { + // mgT = mg * rho + (1 - rho) * g + FloatNdArray mgRho = ND.mul(mgNp, decay); + float oneRHO = 1.0F - decay; + FloatNdArray decayG = ND.mul(oneRHO, gradNp); + return ND.add(mgRho, decayG); + } + + private FloatNdArray calcMom( + float momentum, + 
FloatNdArray mom, + float lr, + FloatNdArray gradNp, + FloatNdArray denomT, + float epsilon) { + // momentum * mom + lr * g / (np.sqrt(denomT + epsilon)) + FloatNdArray moMo = ND.mul(momentum, mom); + FloatNdArray dividend = ND.mul(lr, gradNp); + FloatNdArray divisor = ND.sqrt(ND.add(denomT, epsilon)); + FloatNdArray quotient = ND.div(dividend, divisor); + return ND.add(moMo, quotient); + } + + private FloatNdArray calcVar( + FloatNdArray varNp, FloatNdArray gradNp, float lr, FloatNdArray denomT, float epsilon) { + // var - lr * g / (np.sqrt(denomT) + epsilon) + FloatNdArray dividend = ND.mul(lr, gradNp); + FloatNdArray divisor = ND.add(ND.sqrt(denomT), epsilon); + FloatNdArray quotient = ND.div(dividend, divisor); + return ND.sub(varNp, quotient); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java new file mode 100644 index 00000000000..dac8caa19a3 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java @@ -0,0 +1,16 @@ +package org.tensorflow.framework.optimizers.schedules; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class PiecewiseConstantDecayTest { + + public PiecewiseConstantDecayTest() {} + + @Test + public void testDecay() { + + } + +} \ No newline at end of file diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java new file mode 100644 index 00000000000..a28e56ad7cb --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java @@ -0,0 +1,24 @@ +package org.tensorflow.framework.optimizers.schedules; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class PolynomialDecayTest { + + public PolynomialDecayTest() {} + + @Test + public void testBeginWithCycle() { + float initialLearningRate = 0.1f; + int decaySteps = 10; + float decayRate = 0.96f; + float epsilon = 1e-6f; + PolynomialDecay instance = new PolynomialDecay(initialLearningRate, decaySteps, true); + float expected = initialLearningRate; + float actual = instance.call(0); + assertEquals(expected, actual, epsilon); + + } + +} \ No newline at end of file diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java similarity index 96% rename from tensorflow-keras/src/test/java/org/tensorflow/keras/utils/ND.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index 2855af5af25..0503a41dfc2 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -12,18 +12,20 @@ See the License for the specific language governing permissions and limitations under the License. 
=======================================================================*/ -package org.tensorflow.keras.utils; +package org.tensorflow.framework.utils; -import java.util.Arrays; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + // TODO used in the Callbacks, this should be a part of NDArray? + /** NDArray math Utilities */ public class ND { @@ -126,7 +128,7 @@ public static FloatNdArray square(FloatNdArray a) { * @return the resulting array from the add operation */ public static FloatNdArray add(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -176,7 +178,7 @@ public static FloatNdArray add(float scalar, FloatNdArray a) { * @return the resulting array from the subtraction operation */ public static FloatNdArray sub(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -232,9 +234,10 @@ public static FloatNdArray sub(float scalar, FloatNdArray a) { * @return the resulting array from the muliply operation */ public static FloatNdArray mul(FloatNdArray a, FloatNdArray b) { - if(!a.shape().equals(b.shape())) - throw new IllegalArgumentException(String.format( - "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); + if (!a.shape().equals(b.shape())) + throw new IllegalArgumentException( + String.format( + "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); boolean sameSize = a.shape().size() == b.shape().size(); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -289,7 +292,7 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) { * @return the resulting array from the Divide operation */ public static FloatNdArray div(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -309,8 +312,7 @@ public static FloatNdArray div(FloatNdArray a, FloatNdArray b) { * @return the resulting array from the Divide operation */ public static FloatNdArray div(FloatNdArray a, float scalar) { - if(scalar == 0) - throw new IllegalArgumentException("Cannot divide by zero"); + if (scalar == 0) throw new IllegalArgumentException("Cannot divide by zero"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); a.elements(nDims - 1) @@ -348,7 +350,7 @@ public static FloatNdArray div(float scalar, FloatNdArray a) { * @return the array result of the power operation */ public static FloatNdArray 
pow(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -444,10 +446,10 @@ public static float min(FloatNdArray a) { * @param a the first array * @param a the second array * @return the resulting array with the maximum values between each element of the arrays. - * @throws java.lang.AssertionError if the two arrays are not the same size. + * @throws AssertionError if the two arrays are not the same size. */ public static FloatNdArray max(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); @@ -496,10 +498,10 @@ public static FloatNdArray max(float scalar, FloatNdArray a) { * @param a the first array * @param a the second array * @return the resulting array with the minimum values between each element of the arrays. - * @throws java.lang.AssertionError if the two arrays are not the same size. + * @throws AssertionError if the two arrays are not the same size. */ public static FloatNdArray min(FloatNdArray a, FloatNdArray b) { - if(a.shape().size() != b.shape().size()) + if (a.shape().size() != b.shape().size()) throw new IllegalArgumentException("a and b muse have the same number of dimensions"); FloatNdArray result = NdArrays.ofFloats(a.shape()); int nDims = a.shape().numDimensions(); diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java similarity index 82% rename from tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index cd4b891a039..47c39e820fc 100644 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ -package org.tensorflow.keras.utils; +package org.tensorflow.framework.utils; import org.tensorflow.*; import org.tensorflow.ndarray.FloatNdArray; @@ -41,7 +41,7 @@ public abstract class TestSession implements AutoCloseable { /** Enumerate between Eager and Graph Mode */ public enum Mode { EAGER, - GRAPH; + GRAPH } public static TestSession createEagerSession() { @@ -56,10 +56,21 @@ public static TestSession createTestSession(Mode mode) { return mode == Mode.EAGER ? createEagerSession() : createGraphSession(); } + /** + * Initializer any graph initializers, if in Graph mode, for Eager mode, this method does nothing. 
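   *
   * A minimal graph-mode sketch of how the optimizer tests in this change drive a session
   * (the RMSProp constructor arguments are illustrative values taken from RMSPropTest):
   *
   *     try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH);
   *          RMSProp instance =
   *              new RMSProp(session.getGraph(), 0.05f, 0.9f, 0.0f, 1e-3f, true)) {
   *       Ops tf = session.getTF();
   *       // ... build variables and apply gradients ...
   *       session.run(tf.init()); // runs the graph initializers for variables and slots
   *     }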
+ */ public void initialize() { // empty } + /** + * Returns the Graph if in Graph mode, or null if in EagerMode + * @return the Graph if in Graph mode, or null if in EagerMode + */ + public Graph getGraph() { + return null; + } + /** * Perform session.run() * @@ -67,7 +78,10 @@ public void initialize() { * * @param op The Operation to run */ - public abstract void run(Op op); + public void run(Op op) { + run(op, null); + } + /** * Perform session.run() @@ -75,10 +89,10 @@ public void initialize() { *

    If in eager mode, this does nothing. * * @param op The Operation to run - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ - public abstract void run(Op op, Map, Tensor> feedDict); + public abstract void run(Op op, Map, Tensor> feedMap); /** * Evaluate the expected results versus the actual results @@ -96,15 +110,15 @@ public void evaluate(Number expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( Number expected, Operand input, - Map, Tensor> feedDict) { - evaluate(new Number[] {expected}, input, feedDict); + Map, Tensor> feedMap) { + evaluate(new Number[] {expected}, input, feedMap); } /** @@ -122,13 +136,12 @@ public void evaluate(Number expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. - * @param the data type for the feedDict entries */ - public void evaluate( - Number expected, Op input, Map, Tensor> feedDict) { - evaluate(new Number[] {expected}, input, feedDict); + public void evaluate( + Number expected, Op input, Map, Tensor> feedMap) { + evaluate(new Number[] {expected}, input, feedMap); } /** @@ -148,16 +161,16 @@ public void evaluate(Number[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type for the input */ public void evaluate( Number[] expected, Op input, - Map, Tensor> feedDict) { + Map, Tensor> feedMap) { Output output = input.op().output(0); - evaluate(expected, output, feedDict); + evaluate(expected, output, feedMap); } /** @@ -177,16 +190,16 @@ public void evaluate(Number[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( Number[] expected, Operand input, - Map, Tensor> feedDict) { - Output output = input.asOutput(); - evaluate(expected, output, feedDict); + Map, Tensor> feedMap) { + Output output = input.asOutput(); + evaluate(expected, output, feedMap); } /** @@ -205,15 +218,15 @@ public void evaluate(byte expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. 
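   *
   * For example, the learning-rate decay tests check the value that will actually be fed for the
   * optimizer's learning-rate placeholder before each step (a sketch; instance and update are
   * built as in RMSPropTest):
   *
   *     instance.setLearningRate(0.045f);   // refreshes the optimizer's learning-rate tensor
   *     session.evaluate(
   *         0.045f, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap());
   *     session.run(update, instance.getFeedMap());   // the placeholder is fed from the feed map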
* @param the data type of the input */ public void evaluate( byte expected, Operand input, - Map, Tensor> feedDict) { - evaluate((double) expected, input, feedDict); + Map, Tensor> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -232,15 +245,15 @@ public void evaluate(int expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( int expected, Operand input, - Map, Tensor> feedDict) { - evaluate((double) expected, input, feedDict); + Map, Tensor> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -259,15 +272,15 @@ public void evaluate(long expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( long expected, Operand input, - Map, Tensor> feedDict) { - evaluate((double) expected, input, feedDict); + Map, Tensor> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -286,15 +299,15 @@ public void evaluate(float expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( float expected, Operand input, - Map, Tensor> feedDict) { - evaluate((double) expected, input, feedDict); + Map, Tensor> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -313,14 +326,14 @@ public void evaluate(double expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public abstract void evaluate( double expected, Operand input, - Map, Tensor> feedDict); + Map, Tensor> feedMap); /** * Evaluate the expected results versus the actual results @@ -338,19 +351,19 @@ public void evaluate(byte[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. 
* @param the data type of the input */ public void evaluate( byte[] expected, Operand input, - Map, Tensor> feedDict) { + Map, Tensor> feedMap) { Byte[] iArray = new Byte[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -369,19 +382,19 @@ public void evaluate(int[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( int[] expected, Operand input, - Map, Tensor> feedDict) { + Map, Tensor> feedMap) { Integer[] iArray = new Integer[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -400,19 +413,19 @@ public void evaluate(long[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( long[] expected, Operand input, - Map, Tensor> feedDict) { + Map, Tensor> feedMap) { Long[] iArray = new Long[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -431,19 +444,19 @@ public void evaluate(float[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( float[] expected, Operand input, - Map, Tensor> feedDict) { + Map, Tensor> feedMap) { Float[] iArray = new Float[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -462,19 +475,19 @@ public void evaluate(double[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( double[] expected, Operand input, - Map, Tensor> feedDict) { + Map, Tensor> feedMap) { Double[] iArray = new Double[expected.length]; for (int i = 0; i < expected.length; i++) { iArray[i] = expected[i]; } - evaluate(iArray, input, feedDict); + evaluate(iArray, input, feedMap); } /** @@ -493,14 +506,14 @@ public void evaluate(Number[] expected, Output input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. 
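   *
   * The optimizer tests use the array forms of evaluate to compare a variable against expected
   * values computed in plain Java after each training step (a sketch; var0 and var0Init are set
   * up as in RMSPropTest, and the tolerance comes from setEpsilon):
   *
   *     session.setEpsilon(1e-2f);          // comparison tolerance for floating point checks
   *     session.evaluate(var0Init, var0);   // float[] expected versus the graph variable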
* @param the data type of the input */ public abstract void evaluate( Number[] expected, Output input, - Map, Tensor> feedDict); + Map, Tensor> feedMap); /** * Evaluate the expected results versus the actual results @@ -517,14 +530,14 @@ public void evaluate(String expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( String expected, Operand input, - Map, Tensor> feedDict) { - evaluate(new String[] {expected}, input, feedDict); + Map, Tensor> feedMap) { + evaluate(new String[] {expected}, input, feedMap); } /** @@ -542,12 +555,12 @@ public void evaluate(String expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( - String expected, Op input, Map, Tensor> feedDict) { - evaluate(new String[] {expected}, input, feedDict); + String expected, Op input, Map, Tensor> feedMap) { + evaluate(new String[] {expected}, input, feedMap); } /** @@ -565,15 +578,15 @@ public void evaluate(String[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( String[] expected, Op input, - Map, Tensor> feedDict) { - Output output = input.op().output(0); - evaluate(expected, output, feedDict); + Map, Tensor> feedMap) { + Output output = input.op().output(0); + evaluate(expected, output, feedMap); } /** @@ -583,7 +596,7 @@ public void evaluate( * @param input the actual value */ public void evaluate(String[] expected, Operand input) { - Output output = input.asOutput(); + Output output = input.asOutput(); evaluate(expected, output, null); } @@ -592,13 +605,13 @@ public void evaluate(String[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public abstract void evaluate( String[] expected, Output input, - Map, Tensor> feedDict); + Map, Tensor> feedMap); /** * Evaluate the expected results versus the actual results @@ -615,14 +628,14 @@ public void evaluate(Boolean expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. 
*/ public void evaluate( Boolean expected, Operand input, - Map, Tensor> feedDict) { - evaluate(new Boolean[] {expected}, input, feedDict); + Map, Tensor> feedMap) { + evaluate(new Boolean[] {expected}, input, feedMap); } /** @@ -640,12 +653,12 @@ public void evaluate(Boolean expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( - Boolean expected, Op input, Map, Tensor> feedDict) { - evaluate(new Boolean[] {expected}, input, feedDict); + Boolean expected, Op input, Map, Tensor> feedMap) { + evaluate(new Boolean[] {expected}, input, feedMap); } /** @@ -655,7 +668,7 @@ public void evaluate( * @param input the actual value */ public void evaluate(Boolean[] expected, Op input) { - Output output = input.op().output(0); + Output output = input.op().output(0); evaluate(expected, output, null); } @@ -664,15 +677,15 @@ public void evaluate(Boolean[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( Boolean[] expected, Op input, - Map, Tensor> feedDict) { - Output output = input.op().output(0); - evaluate(expected, output, feedDict); + Map, Tensor> feedMap) { + Output output = input.op().output(0); + evaluate(expected, output, feedMap); } /** @@ -682,7 +695,7 @@ public void evaluate( * @param input the actual value */ public void evaluate(Boolean[] expected, Operand input) { - Output output = input.asOutput(); + Output output = input.asOutput(); evaluate(expected, output, null); } @@ -691,15 +704,15 @@ public void evaluate(Boolean[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void evaluate( Boolean[] expected, Operand input, - Map, Tensor> feedDict) { - Output output = input.asOutput(); - evaluate(expected, output, feedDict); + Map, Tensor> feedMap) { + Output output = input.asOutput(); + evaluate(expected, output, feedMap); } /** @@ -717,13 +730,13 @@ public void evaluate(Boolean[] expected, Output input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. 
*/ public abstract void evaluate( Boolean[] expected, Output input, - Map, Tensor> feedDict); + Map, Tensor> feedMap); /** * Evaluate the expected results versus the actual results @@ -741,7 +754,7 @@ public void evaluate(Operand expected, Output input) { * * @param expected the expected value * @param input the actual value - * @param the data type for the feedDict entries + * @param the data type for the feedMap entries */ public void evaluate(Operand expected, Operand input) { evaluate(expected.asOutput(), input.asOutput(), null); @@ -752,14 +765,14 @@ public void evaluate(Operand expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. - * @param the data type for the feedDict entries + * @param the data type for the feedMap entries */ public abstract void evaluate( Output expected, Output input, - Map, Tensor> feedDict); + Map, Tensor> feedMap); /** * Evaluate the expected results versus the actual results @@ -777,15 +790,15 @@ public void evaluate(FloatNdArray expected, Operand input * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public void evaluate( FloatNdArray expected, Operand input, - Map, Tensor> feedDict) { - evaluate(expected, input.asOutput(), feedDict); + Map, Tensor> feedMap) { + evaluate(expected, input.asOutput(), feedMap); } /** @@ -804,14 +817,14 @@ public void evaluate(FloatNdArray expected, Output input) * * @param expected the expected value * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public abstract void evaluate( FloatNdArray expected, Output input, - Map, Tensor> feedDict); + Map, Tensor> feedMap); /** * Evaluate the actual results using a predicate @@ -831,14 +844,14 @@ public void evaluate(Operand input, Predicate pre * @param input the actual value * @param predicate a predicate that accepts a Number as an argument, if the result of the * predicate is false, then the test will fail - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type of the input */ public abstract void evaluate( Output input, Predicate predicate, - Map, Tensor> feedDict); + Map, Tensor> feedMap); /** * Evaluate the actual results using a predicate @@ -865,13 +878,13 @@ public void print(Operand input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. 
- * @param the data type for the feedDict entries + * @param the data type for the feedMap entries */ public void print( - Operand input, Map, Tensor> feedDict) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput(), feedDict); + Operand input, Map, Tensor> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput(), feedMap); } /** @@ -887,11 +900,11 @@ public void print(Op input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ - public void print(Op input, Map, Tensor> feedDict) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedDict); + public void print(Op input, Map, Tensor> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedMap); } /** @@ -908,13 +921,13 @@ public void print(Output input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type for the input */ public void print( - Output input, Map, Tensor> feedDict) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedDict); + Output input, Map, Tensor> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedMap); } /** @@ -933,15 +946,15 @@ public void print(OutputStream out, Operand input) { * * @param out the output stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. - * @param the data type for the feedDict entries + * @param the data type for the feedMap entries */ public void print( OutputStream out, Operand input, - Map, Tensor> feedDict) { - print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput(), feedDict); + Map, Tensor> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput(), feedMap); } /** @@ -959,12 +972,12 @@ public void print(OutputStream out, Op input) { * * @param out the output stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void print( - OutputStream out, Op input, Map, Tensor> feedDict) { - print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0), feedDict); + OutputStream out, Op input, Map, Tensor> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0), feedMap); } /** @@ -983,15 +996,15 @@ public void print(OutputStream out, Output input) { * * @param out the output stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. 
* @param the data type for the input */ public void print( OutputStream out, Output input, - Map, Tensor> feedDict) { - print(new PrintWriter(new OutputStreamWriter(out)), input, feedDict); + Map, Tensor> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input, feedMap); } /** @@ -1010,15 +1023,15 @@ public void print(Writer writer, Operand input) { * * @param writer the character stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type for the input */ public void print( Writer writer, Operand input, - Map, Tensor> feedDict) { - print(new PrintWriter(writer), input.asOutput(), feedDict); + Map, Tensor> feedMap) { + print(new PrintWriter(writer), input.asOutput(), feedMap); } /** @@ -1036,12 +1049,12 @@ public void print(Writer writer, Op input) { * * @param writer the character stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. */ public void print( - Writer writer, Op input, Map, Tensor> feedDict) { - print(new PrintWriter(writer), input.op().output(0), feedDict); + Writer writer, Op input, Map, Tensor> feedMap) { + print(new PrintWriter(writer), input.op().output(0), feedMap); } /** @@ -1060,15 +1073,15 @@ public void print(Writer writer, Output input) { * * @param writer the character stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type for the input */ public void print( Writer writer, Output input, - Map, Tensor> feedDict) { - print(new PrintWriter(writer), input, feedDict); + Map, Tensor> feedMap) { + print(new PrintWriter(writer), input, feedMap); } /** @@ -1087,14 +1100,14 @@ public void print(PrintWriter writer, Output input) { * * @param writer the character stream * @param input the actual value - * @param feedDict The dictionary of values to pass to the feed() operation of the runner, + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, * required for placeholders. * @param the data type for the input */ public abstract void print( PrintWriter writer, Output input, - Map, Tensor> feedDict); + Map, Tensor> feedMap); /** * Get the TensorFlow Ops for this test session diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java deleted file mode 100644 index f9f796d7738..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Nadam.java +++ /dev/null @@ -1,429 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the ); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.*; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Placeholder; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TType; - -import java.util.*; - -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; - -/** Nadam Optimizer that implements the NAdam algorithm. */ -public class Nadam extends org.tensorflow.framework.optimizers.Optimizer - implements OptimizerInterface, AutoCloseable { - - public static final String FIRST_MOMENT = "m"; - public static final String SECOND_MOMENT = "v"; - public static final String MOMENTUM = "momentum"; - - public static final String LEARNING_RATE_KEY = "learning_rate"; - public static final String EPSILON_KEY = "epsilon"; - public static final String BETA_ONE_KEY = "beta_1"; - public static final String BETA_TWO_KEY = "beta_2"; - - public static final float LEARNING_RATE_DEFAULT = 0.001F; - public static final float EPSILON_DEFAULT = 1e-07F; - public static final float BETA_ONE_DEFAULT = 0.9F; - public static final float BETA_TWO_DEFAULT = 0.999F; - - private final Map config = new HashMap<>(); - - private float learningRate; - private Tensor learningRateTensor; - private final Placeholder learningRatePlaceholder; - private Map, Tensor> feedDict; - private final float betaOne; - private final float betaTwo; - private final float epsilon; - private final float decayBase = 0.96F; - private final float decay = 0.004F; - - private long iterations = 0; - - private Constant betaOneConst; - private Constant betaTwoConst; - private Constant localStepConst; - private Constant nextStepConst; - - private Constant decayBaseConst; - private Constant decayConst; - private Constant epsilonConst; - - private Variable betaOnePower; - private Variable betaTwoPower; - private Variable momentum; - - private Operand m_t; - private Operand m_t_1; - private Operand m_schedule_new; - private Operand m_schedule_next; - private Operand one_minus_beta_1; - private Operand one_minus_beta_2; - private Operand one_minus_m_t; - private Operand one_minus_m_schedule_new; - private Operand one_minus_m_schedule_next; - private Operand v_t_prime_denominator; - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - */ - public Nadam(Ops tf) { - this(tf, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param name name for the operations created when applying gradients. Defaults to "Nadam". - */ - public Nadam(Ops tf, String name) { - this(tf, name, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. - */ - public Nadam(Ops tf, float learningRate) { - this(tf, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. 
- * - * @param tf the TensorFlow Ops - * @param name name for the operations created when applying gradients. Defaults to "Adamax". - * @param learningRate The learning rate. - */ - public Nadam(Ops tf, String name, float learningRate) { - this(tf, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. - * @param betaOne The exponential decay rate for the 1st moment estimates. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. - * @param epsilon A small constant for numerical stability. - */ - public Nadam(Ops tf, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(assertGraph(tf)); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE) - .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; - initConfig(learningRate, betaOne, betaTwo, epsilon); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. - * - * @param tf the TensorFlow Ops - * @param name name for the operations created when applying gradients. - * @param learningRate The learning rate. - * @param betaOne The exponential decay rate for the 1st moment estimates. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. - * @param epsilon A small constant for numerical stability. - */ - public Nadam( - Ops tf, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(assertGraph(tf), name); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.learningRatePlaceholder = - tf.withSubScope(LEARNING_RATE) - .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; - - initConfig(learningRate, betaOne, betaTwo, epsilon); - } - - /** - * Create an Optimizer that implements the NAdam algorithm. 
- * - * @param tf the TensorFlow Ops - * @param config a config object to initialize - */ - public static Nadam create(Ops tf, Map config) { - String name = (String) config.get(NAME_KEY); - float learningRate = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - float epsilon = (float) config.getOrDefault(EPSILON_KEY, EPSILON_DEFAULT); - float betaOne = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - float betaTwo = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - if (name == null) { - return new Nadam(tf, learningRate, betaOne, betaTwo, epsilon); - } else { - return new Nadam(tf, name, learningRate, betaOne, betaTwo, epsilon); - } - } - - /** {@inheritDoc} */ - @Override - public Map getConfig() { - return config; - } - - /** {@inheritDoc} */ - @Override - public float getLearningRate() { - return this.learningRate; - } - - /** {@inheritDoc} */ - @Override - public final void setLearningRate(float learningRate) { - this.learningRate = learningRate; - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - } - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.feedDict = Collections.singletonMap(this.learningRatePlaceholder, this.learningRateTensor); - } - - /** - * Get the Feed Dictionary for the run methods to set the Placeholder values(s) - * - * @return the current Feed Dictionary for the run methods - */ - public Map, Tensor> getFeedDict() { - return this.feedDict; - } - - /** {@inheritDoc} */ - @Override - public void close() throws Exception { - if (this.learningRateTensor != null) { - this.learningRateTensor.close(); - this.learningRateTensor = null; - } - } - - /** {@inheritDoc} */ - @Override - protected void createSlots(List> variables) { - for (Output v : variables) { - createNadamSlot(v.asOutput()); - } - betaOnePower = tf.withName("beta1_power").variable(Shape.scalar(), TFloat32.DTYPE); - Assign betaOnePowerInit = tf.assign(betaOnePower, tf.constant(betaOne)); - ((Graph) tf.scope().env()).addInitializer(betaOnePowerInit); - - betaTwoPower = tf.withName("beta2_power").variable(Shape.scalar(), TFloat32.DTYPE); - Assign betaTwoPowerInit = tf.assign(betaTwoPower, tf.constant(betaTwo)); - ((Graph) tf.scope().env()).addInitializer(betaTwoPowerInit); - - momentum = tf.withName("momentum").variable(Shape.scalar(), TFloat32.DTYPE); - Assign momentumInit = tf.assign(momentum, tf.constant(1.0F)); - ((Graph) tf.scope().env()).addInitializer(momentumInit); - } - - /** - * Create slots for first and second momements and momentum - * - * @param v the variable - * @param the data type or the Variable - */ - private void createNadamSlot(Output v) { - Operand firstMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); - createSlot(v.asOutput(), FIRST_MOMENT, firstMomentInitializer); - Operand secondMomentInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.dataType())); - createSlot(v.asOutput(), SECOND_MOMENT, secondMomentInitializer); - - Operand momentumInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.dataType())); - createSlot(v.asOutput(), MOMENTUM, momentumInitializer); - } - - /** {@inheritDoc} */ - @Override - protected Optional prepare(String scopeName) { - Constant one = tf.constant(1.0F); - Constant point5 = tf.constant(0.5F); - - betaOneConst = tf.constant(betaOne); - betaTwoConst = tf.constant(betaTwo); - localStepConst = tf.constant(this.iterations + 1); - nextStepConst = 
tf.constant(this.iterations + 2); - decayConst = tf.constant(decay); - decayBaseConst = tf.constant(this.decayBase); - epsilonConst = tf.constant(this.epsilon); - - // m_t = beta_1_t * (1. - 0.5 * ( math_ops.pow(decay_base, self._initial_decay * local_step))) - m_t = - tf.math.mul( - betaOneConst, - tf.math.sub( - one, - tf.math.mul( - point5, - tf.math.pow( - decayBaseConst, - tf.math.mul(decayConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE)))))); - // m_t_1 = beta_1_t * (1. - 0.5 * ( math_ops.pow(decay_base, self._initial_decay * next_step))) - m_t_1 = - tf.math.mul( - betaOneConst, - tf.math.sub( - one, - tf.math.mul( - point5, - tf.math.pow( - decayBaseConst, - tf.math.mul(decayConst, tf.dtypes.cast(nextStepConst, TFloat32.DTYPE)))))); - - // m_schedule_new = math_ops.cast(self._m_cache_read, var_dtype) * m_t - m_schedule_new = tf.math.mul(momentum, m_t); - // if var_dtype is self._m_cache.dtype: - // m_schedule_new = array_ops.identity(state_ops.assign( - // self._m_cache, m_schedule_new, use_locking=self._use_locking)) - m_schedule_new = tf.identity(tf.assign(momentum, m_schedule_new, Assign.useLocking(true))); - // m_schedule_next = m_schedule_new * m_t_1 - m_schedule_next = tf.math.mul(m_schedule_new, m_t_1); - - // 1 - beta_1_t - one_minus_beta_1 = tf.math.sub(one, betaOneConst); - // 1 - beta_2_t, - one_minus_beta_2 = tf.math.sub(one, betaTwoConst); - // 1. - m_t, - one_minus_m_t = tf.math.sub(one, m_t); - // 1. - m_schedule_new - one_minus_m_schedule_new = tf.math.sub(one, m_schedule_new); - // 1. - m_schedule_next - one_minus_m_schedule_next = tf.math.sub(one, m_schedule_next); - // 1. - math_ops.pow(beta_2_t, local_step) - v_t_prime_denominator = - tf.math.sub(one, tf.math.pow(betaTwoConst, tf.dtypes.cast(localStepConst, TFloat32.DTYPE))); - return Optional.empty(); - } - - /** {@inheritDoc} */ - @Override - protected Op applyDense(Output gradient, Output variable) { - DataType dType = gradient.dataType(); - Variable m = getSlot(variable, FIRST_MOMENT).get(); // first Moment - Variable v = getSlot(variable, SECOND_MOMENT).get(); // Second Moment - - // g_prime = grad / coefficients['one_minus_m_schedule_new'] - Operand g_prime = tf.math.div(gradient, tf.dtypes.cast(one_minus_m_schedule_new, dType)); - // m_t = (coefficients['beta_1_t'] * m + coefficients['one_minus_beta_1_t'] * grad) - Operand m_t = - tf.math.add( - tf.math.mul(tf.dtypes.cast(betaOneConst, dType), m), - tf.math.mul(tf.dtypes.cast(one_minus_beta_1, dType), gradient)); - // m_t = state_ops.assign(m, m_t, use_locking=self._use_locking) - // update m - m_t = tf.assign(m, m_t, Assign.useLocking(true)); - - // m_t_prime = m_t / coefficients['one_minus_m_schedule_next'] - Operand m_t_prime = tf.math.div(m_t, tf.dtypes.cast(one_minus_m_schedule_next, dType)); - - // v_t = (coefficients['beta_2_t'] * v + coefficients['one_minus_beta_2_t'] * - // math_ops.square(grad)) - Operand v_t = - tf.math.add( - tf.math.mul(tf.dtypes.cast(betaTwoConst, dType), v), - tf.math.mul(tf.dtypes.cast(one_minus_beta_2, dType), tf.math.square(gradient))); - // v_t = state_ops.assign(v, v_t, use_locking=self._use_locking) - // update v - v_t = tf.assign(v, v_t, Assign.useLocking(true)); - - // v_t_prime = v_t / coefficients['v_t_prime_denominator'] - Operand v_t_prime = tf.math.div(v_t, tf.dtypes.cast(v_t_prime_denominator, dType)); - - // m_t_bar = (coefficients['one_minus_m_t'] * g_prime + coefficients['m_t_1'] * m_t_prime) - Operand m_t_bar = - tf.math.add( - tf.math.mul(tf.dtypes.cast(one_minus_m_t, dType), g_prime), - 
tf.math.mul(tf.dtypes.cast(m_t_1, dType), m_t_prime)); - // var_t = var - coefficients['lr_t'] * m_t_bar / (math_ops.sqrt(v_t_prime) + - // coefficients['epsilon']) - Operand var_t = - tf.math.sub( - variable, - tf.math.div( - tf.math.mul(tf.dtypes.cast(this.learningRatePlaceholder, dType), m_t_bar), - tf.math.add(tf.math.sqrt(v_t_prime), tf.dtypes.cast(epsilonConst, dType)))); - // assign(var, var_t, use_locking=self._use_locking) - return tf.assign(variable, var_t, Assign.useLocking(true)); - } - - /** - * Gathers up the update operations into a single op that can be used as a run target. - * - *

    Adds the betaOne, betaTwo and mu updates to the end of the updates list. - * - * @param updateOperations The update operations. - * @param name The name of the run target. - * @return A NoOp with a control dependency on each update operation. - */ - @Override - protected Op finish(List updateOperations, String name) { - iterations++; // increment the step; - updateOperations.add(tf.assign(betaOnePower, tf.math.mul(betaOnePower, betaOneConst))); - updateOperations.add(tf.assign(betaTwoPower, tf.math.mul(betaTwoPower, betaTwoConst))); - return super.finish(updateOperations, name); - } - - /** {@inheritDoc} */ - @Override - public String getOptimizerName() { - return "Nadam"; - } - - /** - * Sets the config object based on the current state of the Optmizer. - * - * @param learningRate The learning rate. Defaults to 0.001. - * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9. - * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. - * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the - * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm - * 1 of the paper. Defaults to 1e-7. - */ - private void initConfig(float learningRate, float betaOne, float betaTwo, float epsilon) { - config.put(NAME_KEY, this.getOptimizerName()); - config.put(LEARNING_RATE_KEY, learningRate); - config.put(EPSILON_KEY, epsilon); - config.put(BETA_ONE_KEY, betaOne); - config.put(BETA_TWO_KEY, betaTwo); - } -} diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java deleted file mode 100644 index 183c71dd976..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/OptimizerInterface.java +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.Graph; -import org.tensorflow.op.Ops; - -import java.util.Map; - -/** The main Interface for Keras Optimizers */ -public interface OptimizerInterface { - - /** The value for the name key in the Config object */ - String NAME_KEY = "name"; - - /** - * Get a TensorFlow Graph from the Ops. 
- * - * @param tf the TensorFlow Ops - * @return the graph - * @throws java.lang.IllegalArgumentException if the TensorFlow Ops does not represent Graph mode - */ - static Graph assertGraph(Ops tf) { - if (!tf.scope().env().isGraph()) { - throw new IllegalArgumentException( - "Invalid environment, Optimizers can only be used in Graph Mode"); - } - return (Graph) tf.scope().env(); - } - - /** - * Return the config object used to initialize the Optimizer - * - * @return the config object used to initialize the Optimizer - */ - Map getConfig(); -} diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java deleted file mode 100644 index aecd8dcf537..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/Optimizers.java +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.op.Ops; - -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; -import java.util.HashMap; -import java.util.Map; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.logging.Level; -import java.util.logging.Logger; - -/** - * Functions to get an Optimizer based on String name, an Optimizer class, or lambda function. - * - *

    <p>Example: - * - * <pre>

    - *     Adam instance = Optimizers.get(tf, "adam");
    - *     Ftrl instance = Optimizers.get(tf, ltf -> new Ftrl(ltf, 0.1f));
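    - *     // A further hedged sketch, not in the original javadoc: the three-argument
    - *     // get(Ops, Object, Map) overload defined later in this class also consults a
    - *     // caller-supplied map of factories; "myFtrl" is a hypothetical key, and the
    - *     // generic parameters are assumed to be Map<String, Function<Ops, Optimizer>>.
    - *     Map<String, Function<Ops, Optimizer>> custom = new HashMap<>();
    - *     custom.put("myFtrl", ltf -> new Ftrl(ltf, 0.1f));
    - *     Optimizer instance = Optimizers.get(tf, "myFtrl", custom);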
    - * </pre>
    - */ -public class Optimizers { - - static Map> map = - new HashMap>() { - { - put("adadelta", AdaDelta::new); - put("adagrad", AdaGrad::new); - put("adagrad-da", AdaGradDA::new); - put("adam", Adam::new); - put("adamax", Adamax::new); - put("ftrl", Ftrl::new); - put("nadam", Nadam::new); - put("rmsprop", RMSProp::new); - put("sgd", SGD::new); - } - }; - - /** - * Get an Optimizer - * - * @param optimizerFunction either a String that identifies the Optimizer, an Optimizer class, or - * an Optimizer object. - * @return the Optimizer object or null if not found. - */ - public static Optimizer get(Ops tf, Object optimizerFunction) { - return get(tf, optimizerFunction, null); - } - - /** - * Get an Optimizer - * - * @param func a lamda function that returns the Optimizer - * @return the Intializer object - */ - public static Optimizer get(Ops tf, Function func) { - return func.apply(tf); - } - - /** - * Get an Optimizer - * - * @param optimizerFunction either a String that identifies the Optimizer, an Optimizer class, or - * * an Optimizer object. - * @param custom_functions a map of Optimizer lambdas that will be queried if the Optimizer is not - * found in the standard keys - * @return the Optimizer object - */ - public static Optimizer get( - Ops tf, Object optimizerFunction, Map> custom_functions) { - if (optimizerFunction != null) { - if (optimizerFunction instanceof String) { - String s = - optimizerFunction - .toString(); // do this for Java 8 rather than Pattern Matching for instanceof - Function function = map.get(s); - if (function == null && custom_functions != null) { - function = custom_functions.get(s); - } - return function != null ? function.apply(tf) : null; - } else if (optimizerFunction instanceof Class) { - // do this for Java 8 rather than Pattern Matching for instanceof - Class c = (Class) optimizerFunction; - try { - Constructor ctor = c.getConstructor(Ops.class); - return (Optimizer) ctor.newInstance(tf); - } catch (NoSuchMethodException - | InstantiationException - | IllegalAccessException - | IllegalArgumentException - | InvocationTargetException ex) { - Logger.getLogger(Optimizers.class.getName()).log(Level.SEVERE, null, ex); - } - } else if (optimizerFunction instanceof Optimizer) { - return (Optimizer) optimizerFunction; - } else if (optimizerFunction instanceof Function) { - return ((Function) optimizerFunction).apply(tf); - } else if (optimizerFunction instanceof Supplier) { - return ((Supplier) optimizerFunction).get(); - } - } else { - return null; - } - - throw new IllegalArgumentException( - "optimizerFunction must be a symbolic name, Optimizer, Function, Supplier or a Class object"); - } -} diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java deleted file mode 100644 index 03fc4c01f71..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/RMSProp.java +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the ); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.op.Ops; - -import java.util.HashMap; -import java.util.Map; - -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; - -/** RMSProp Optimizer that implements the RMSProp algorithm. */ -public class RMSProp extends org.tensorflow.framework.optimizers.RMSProp - implements OptimizerInterface { - - public static final String LEARNING_RATE_KEY = "learning_rate"; - public static final String DECAY_KEY = "decay"; - public static final String MOMENTUM_KEY = "momentum"; - public static final String EPSILON_KEY = "epsilon"; - public static final String CENTERED_KEY = "centered"; - - public static final float LEARNING_RATE_DEFAULT = 0.001F; - public static final float DECAY_DEFAULT = 0.9F; - public static final float MOMENTUM_DEFAULT = 0.0F; - public static final float EPSILON_DEFAULT = 1e-07F; - public static final boolean CENTERED_DEFAULT = false; - - private Map config = new HashMap<>(); - - /** - * Create an RMSProp Optimizer with the following defaults, name="RMSProp", learning_rate=0.001, - * decay=0.9, momentum=0.0, epsilon=1e-07, centered=false - * - * @param tf the TensorFlow Ops - */ - public RMSProp(Ops tf) { - this( - tf, - LEARNING_RATE_DEFAULT, - DECAY_DEFAULT, - MOMENTUM_DEFAULT, - EPSILON_DEFAULT, - CENTERED_DEFAULT); - } - - /** - * Create an RMSProp Optimizer with the following defaults, name="RMSProp", decay=0.9, - * momentum=0.0, epsilon=1e-07, centered=false - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. - */ - public RMSProp(Ops tf, float learningRate) { - this(tf, learningRate, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT, CENTERED_DEFAULT); - } - - /** - * Create an RMSProp Optimizer with the following defaults, decay=0.9, momentum=0.0, - * epsilon=1e-07, centered=false - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients. Defaults to "RMSProp" - * @param learningRate The learning rate. - */ - public RMSProp(Ops tf, String name, float learningRate) { - this( - tf, name, learningRate, DECAY_DEFAULT, MOMENTUM_DEFAULT, EPSILON_DEFAULT, CENTERED_DEFAULT); - } - - /** - * Create an RMSProp Optimizer - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. Defaults to 0.001. - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum hyperparameter that accelerates descent in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param epsilon A small constant for numerical stability. - * @param centered If True, gradients are normalized by the estimated variance of the gradient; if - * False, by the uncentered second moment. - */ - public RMSProp( - Ops tf, float learningRate, float decay, float momentum, float epsilon, boolean centered) { - super(assertGraph(tf), learningRate, decay, momentum, epsilon, centered); - initConfig(learningRate, decay, momentum, epsilon, centered); - } - - /** - * Create an RMSProp Optimizer - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients. Defaults to "RMSProp" - * @param learningRate The learning rate. Defaults to 0.001. - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. 
- * @param momentum hyperparameter that accelerates descent in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param epsilon A small constant for numerical stability. - * @param centered If True, gradients are normalized by the estimated variance of the gradient; if - * False, by the uncentered second moment. - */ - public RMSProp( - Ops tf, - String name, - float learningRate, - float decay, - float momentum, - float epsilon, - boolean centered) { - super(assertGraph(tf), name, learningRate, decay, momentum, epsilon, centered); - initConfig(learningRate, decay, momentum, epsilon, centered); - } - - /** - * Create a RMSProp Optimizer using a configuration - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize the Optimizer, the config object has keys for - * "name", "learning_rate", "decay", "momentum", "epsilon" and "centered". If a key is missing - * the default value is used. - * @return the RMSProp optimizer - */ - public static RMSProp fromConfig(Ops tf, Map config) { - return create(tf, config); - } - - /** - * Create a RMSProp Optimizer using a configuration - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize the Optimizer, the config object has keys for - * "name", "learning_rate", "decay", "momentum", "epsilon" and "centered". If a key is missing - * the default value is used. - * @return the RMSProp optimizer - */ - public static RMSProp create(Ops tf, Map config) { - - String name = (String) config.get(NAME_KEY); - float learningRate = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - float decay = (float) config.getOrDefault(DECAY_KEY, DECAY_DEFAULT); - float momentum = (float) config.getOrDefault(MOMENTUM_KEY, MOMENTUM_DEFAULT); - float epsilon = (float) config.getOrDefault(EPSILON_KEY, EPSILON_DEFAULT); - boolean centered = (boolean) config.getOrDefault(CENTERED_KEY, CENTERED_DEFAULT); - if (name == null) { - return new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); - } else { - return new RMSProp(tf, name, learningRate, decay, momentum, epsilon, centered); - } - } - - /** - * Initialize the configuration based on which constructor is called. - * - * @param learningRate The learning rate. Defaults to 0.001. - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum hyperparameter that accelerates descent in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param epsilon A small constant for numerical stability. - * @param centered If True, gradients are normalized by the estimated variance of the gradient; if - * False, by the uncentered second moment. - */ - private void initConfig( - float learningRate, float decay, float momentum, float epsilon, boolean centered) { - config.put(NAME_KEY, this.getOptimizerName()); - config.put(LEARNING_RATE_KEY, learningRate); - config.put(DECAY_KEY, decay); - config.put(MOMENTUM_KEY, momentum); - config.put(EPSILON_KEY, epsilon); - config.put(CENTERED_KEY, centered); - } - - /** {@inheritDoc} */ - @Override - public Map getConfig() { - return config; - } -} diff --git a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java b/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java deleted file mode 100644 index 5e7155c2ab5..00000000000 --- a/tensorflow-keras/src/main/java/org/tensorflow/keras/optimizers/SGD.java +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the ); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.tensorflow.op.Ops; - -import java.util.HashMap; -import java.util.Map; - -import static org.tensorflow.keras.optimizers.OptimizerInterface.assertGraph; - -/** Stochastic Gradient Descent and momentum optimizer. */ -public class SGD extends org.tensorflow.framework.optimizers.Momentum - implements OptimizerInterface { - - public static final String LEARNING_RATE_KEY = "learning_rate"; - public static final String MOMENTUM_KEY = "momentum"; - public static final String NESTEROV_KEY = "nesterov"; - - public static final float LEARNING_RATE_DEFAULT = 0.01F; - public static final float MOMENTUM_DEFAULT = 0.0F; - public static final boolean NESTEROV_DEFAULT = false; - - private Map config = new HashMap<>(); - - /** - * Create a Stochastic Gradient Descent optimizer using defaults: name="SGD", learning_rate=0.01, - * momentum=0.0, and nesterov=false - * - * @param tf the TensorFlow Ops - */ - public SGD(Ops tf) { - this(tf, LEARNING_RATE_DEFAULT, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); - } - - /** - * Create a Stochastic gradient descent optimizer using defaults: name="SGD", momentum=0.0, and - * nesterov=false - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. Defaults to 0.01. - */ - public SGD(Ops tf, float learningRate) { - this(tf, learningRate, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); - } - - /** - * Create a Stochastic gradient descent optimizer using defaults: name="SGD", and nesterov=false - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - */ - public SGD(Ops tf, float learningRate, float momentum) { - this(tf, learningRate, momentum, NESTEROV_DEFAULT); - } - - /** - * Create a Stochastic gradient descent optimizer using defaults: momentum=0.0, and nesterov=false - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients - * @param learningRate The learning rate. Defaults to 0.01. - */ - public SGD(Ops tf, String name, float learningRate) { - this(tf, name, learningRate, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); - } - - /** - * create a Stochastic gradient descent optimizer using defaults: momentum=0.0, and nesterov=false - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients - * @param learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. 
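    - *     A minimal usage sketch (hedged; variable names are assumed and not part of the
    - *     original javadoc):
    - *     SGD sgd = new SGD(tf, "SGD", 0.01f, 0.9f);
    - *     Op update = sgd.applyGradients(gradsAndVars, "SGDTest");
    - *     // gradsAndVars is a List of Optimizer.GradAndVar pairs built from gradient and
    - *     // variable outputs, as in the optimizer tests later in this patch series.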
- */ - public SGD(Ops tf, String name, float learningRate, float momentum) { - this(tf, name, learningRate, momentum, NESTEROV_DEFAULT); - } - - /** - * Create a Stochastic gradient descent optimizer - * - * @param tf the TensorFlow Ops - * @param learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to `false`. - */ - public SGD(Ops tf, float learningRate, float momentum, boolean useNesterov) { - super(assertGraph(tf), learningRate, momentum, useNesterov); - if (momentum < 0 || momentum > 1) - throw new IllegalArgumentException("\"momentum\" must be between [0, 1]."); - initConfig(learningRate, momentum, useNesterov); - } - - /** - * Create a Stochastic gradient descent optimizer - * - * @param tf the TensorFlow Ops - * @param name prefix for the operations created when applying gradients - * @param learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to `false`. - */ - public SGD(Ops tf, String name, float learningRate, float momentum, boolean useNesterov) { - super(assertGraph(tf), name, learningRate, momentum, useNesterov); - if (momentum < 0 || momentum > 1) - throw new IllegalArgumentException("\"momentum\" must be between [0, 1]."); - initConfig(learningRate, momentum, useNesterov); - } - - /** - * Create a Stochastic gradient descent optimizer - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize, the config object has keys for "name", - * "learning_rate", "momentum", and "nesterov". If a key is missing the default value is used. - * @return the Stochastic gradient descent optimizer - */ - public static SGD fromConfig(Ops tf, Map config) { - return create(tf, config); - } - - /** - * Create a Stochastic gradient descent optimizer - * - * @param tf the TensorFlow Ops - * @param config a config object to initialize, the config object has keys for "name", - * "learning_rate", "momentum", and "nesterov". If a key is missing the default value is used. - * @return the Stochastic gradient descent optimizer - */ - public static SGD create(Ops tf, Map config) { - - String name = (String) config.get(NAME_KEY); - float learningRate = (float) config.getOrDefault(LEARNING_RATE_KEY, LEARNING_RATE_DEFAULT); - float momentum = (float) config.getOrDefault(MOMENTUM_KEY, MOMENTUM_DEFAULT); - boolean nesterov = (boolean) config.getOrDefault(NESTEROV_KEY, NESTEROV_DEFAULT); - if (name == null) { - return new SGD(tf, learningRate, momentum, nesterov); - } else { - return new SGD(tf, name, learningRate, momentum, nesterov); - } - } - - /** - * Initialize the configuration ased on which constructor is called. - * - * @param learningRate learningRate The learning rate. Defaults to 0.01. - * @param momentum hyperparameter that accelerates SGD in the relevant direction and dampens - * oscillations. Must be between [0, 1]. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to `false`. 
- */ - private void initConfig(float learningRate, float momentum, boolean useNesterov) { - config.put(NAME_KEY, this.getOptimizerName()); - config.put(LEARNING_RATE_KEY, learningRate); - config.put(MOMENTUM_KEY, momentum); - config.put(NESTEROV_KEY, useNesterov); - } - - /** { @inheritDoc } */ - @Override - public Map getConfig() { - return config; - } - - // overide the momentum name to return "SGD" - /** {@inheritDoc} */ - @Override - public String getOptimizerName() { - return "SGD"; - } -} diff --git a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java b/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java deleted file mode 100644 index 7651872643b..00000000000 --- a/tensorflow-keras/src/test/java/org/tensorflow/keras/optimizers/RMSPropTest.java +++ /dev/null @@ -1,444 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.keras.optimizers; - -import org.junit.jupiter.api.*; -import org.tensorflow.framework.optimizers.Optimizer; -import org.tensorflow.keras.utils.ND; -import org.tensorflow.keras.utils.TestSession; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Op; -import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Variable; -import org.tensorflow.types.TFloat32; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.RMSProp.*; -import static org.tensorflow.keras.optimizers.Ftrl.LEARNING_RATE_KEY; -import static org.tensorflow.keras.optimizers.OptimizerInterface.NAME_KEY; -import static org.tensorflow.keras.optimizers.RMSProp.*; - -/** Test cases for RMSProp Optimizer */ -public class RMSPropTest { - - private TestSession.Mode tf_mode = TestSession.Mode.GRAPH; - - final int VAR_T = 0; - final int MG_T = 1; - final int RMS_T = 2; - final int MOM_T = 3; - - int index; - - public RMSPropTest() {} - - @BeforeAll - public static void setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - - /** Test of create method, of class RMSProp. 
*/ - @Test - public void testCreate() { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - Map config = new HashMap<>(); - config.put(NAME_KEY, "Ftrl"); - config.put(LEARNING_RATE_KEY, 2.0F); - config.put(DECAY_KEY, DECAY_DEFAULT); - config.put(MOMENTUM_KEY, MOMENTUM_DEFAULT); - config.put(EPSILON_KEY, EPSILON_DEFAULT); - config.put(CENTERED_KEY, CENTERED_DEFAULT); - Ftrl expResult = new Ftrl(tf, 2.0F); - Ftrl result = Ftrl.create(tf, config); - assertEquals(expResult.getConfig(), result.getConfig()); - } - } - - Object[][] _test_param_values = { - // learning_rate, rho (decay), momentum, epsilon, centered - {0.05F, 0.9F, 0.0F, 1e-3F, true}, - {0.05F, 0.9F, 0.0F, 1e-3F, false}, - {0.1F, 0.9F, 0.0F, 1e-3F, true}, - {0.01F, 0.9F, 0.0F, 1e-5F, true}, - {0.01F, 0.9F, 0.9F, 1e-5F, true} - }; - - @Test - public void testDense() { - - int numSteps = 3; - - for (int run = 0; run < _test_param_values.length; run++) { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - session.setEpsilon(1e-2f); - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.2F}; - float[] grads1_init = {0.01F, 0.2F}; - final float epsilon1 = 1e-2F; - - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); - - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); - - Constant grads0 = tf.constant(grads0_init); - Constant grads1 = tf.constant(grads1_init); - - // learning_rate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) _test_param_values[run][0]; - float decay = (float) _test_param_values[run][1]; - float momentum = (float) _test_param_values[run][2]; - float epsilon = (float) _test_param_values[run][3]; - boolean centered = (boolean) _test_param_values[run][4]; - - RMSProp instance = new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); - - /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - - Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); - - /* initialize the local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /** initialize the accumulators */ - session.run(tf.init()); - - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); - - Variable mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; - Variable mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; - Variable mom0 = - momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; - Variable mom1 = - momentum > 0.F ? 
instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; - Variable rms0 = instance.getSlot(var0.asOutput(), RMS).get(); - Variable rms1 = instance.getSlot(var1.asOutput(), RMS).get(); - - float[] zeros = {0.0F, 0.0F}; - float[] ones = {1.0F, 1.0F}; // temp to match RMSProp - FloatNdArray mg0_np = NdArrays.vectorOf(zeros); - FloatNdArray mg1_np = NdArrays.vectorOf(zeros); - FloatNdArray rms0_np = NdArrays.vectorOf(ones); - FloatNdArray rms1_np = NdArrays.vectorOf(ones); - FloatNdArray mom0_np = NdArrays.vectorOf(zeros); - FloatNdArray mom1_np = NdArrays.vectorOf(zeros); - - for (int i = 0; i < numSteps; i++) { - session.run(update, instance.getFeedDict()); - FloatNdArray[] result0 = - calc( - var0_np, - grads0_np, - mg0_np, - rms0_np, - mom0_np, - learningRate, - decay, - momentum, - epsilon, - centered); - var0_np = result0[VAR_T]; - mg0_np = result0[MG_T]; - rms0_np = result0[RMS_T]; - mom0_np = result0[MOM_T]; - - FloatNdArray[] result1 = - calc( - var1_np, - grads1_np, - mg1_np, - rms1_np, - mom1_np, - learningRate, - decay, - momentum, - epsilon, - centered); - - var1_np = result1[VAR_T]; - mg1_np = result1[MG_T]; - rms1_np = result1[RMS_T]; - mom1_np = result1[MOM_T]; - - if (centered) { - session.evaluate(mg0_np, mg0); - session.evaluate(mg0_np, mg0); - } - if (momentum > 0.F) { - session.evaluate(mom0_np, mom0); - session.evaluate(mom1_np, mom1); - } - - /* TODO the values returned from rms slot, do not match what I see in the python test */ - session.evaluate(rms0_np, rms0); - session.evaluate(rms1_np, rms1); - - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); - } - } - } - } - - @Test - public void testWithLearningRateDecay() { - int numSteps = 3; - - for (int run = 0; run < _test_param_values.length; run++) { - try (TestSession session = TestSession.createTestSession(tf_mode)) { - Ops tf = session.getTF(); - session.setEpsilon(1e-2f); - float[] var0_init = {1.0F, 2.0F}; - float[] var1_init = {3.0F, 4.0F}; - float[] grads0_init = {0.1F, 0.2F}; - float[] grads1_init = {0.01F, 0.2F}; - final float epsilon1 = 1e-2F; - - FloatNdArray var0_np = NdArrays.vectorOf(var0_init); - FloatNdArray var1_np = NdArrays.vectorOf(var1_init); - FloatNdArray grads0_np = NdArrays.vectorOf(grads0_init); - FloatNdArray grads1_np = NdArrays.vectorOf(grads1_init); - - Shape shape0 = Shape.of(var0_init.length); - Shape shape1 = Shape.of(var1_init.length); - Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); - Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); - - Assign var0Initializer = tf.assign(var0, tf.constant(var0_init)); - Assign var1Initializer = tf.assign(var1, tf.constant(var1_init)); - - Constant grads0 = tf.constant(grads0_init); - Constant grads1 = tf.constant(grads1_init); - - // learning_rate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) _test_param_values[run][0]; - float decay = (float) _test_param_values[run][1]; - float momentum = (float) _test_param_values[run][2]; - float epsilon = (float) _test_param_values[run][3]; - boolean centered = (boolean) _test_param_values[run][4]; - - RMSProp instance = new RMSProp(tf, learningRate, decay, momentum, epsilon, centered); - - /* build the GradsAnvVars */ - List gradsAndVars = new ArrayList<>(); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); - gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - - Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); - - /* initialize the 
local variables */ - session.run(var0Initializer); - session.run(var1Initializer); - - /** initialize the accumulators */ - session.run(tf.init()); - - /** make sure the variables were initialized properly */ - session.evaluate(var0_init, var0); - session.evaluate(var1_init, var1); - - Variable mg0 = centered ? instance.getSlot(var0.asOutput(), MG).get() : null; - Variable mg1 = centered ? instance.getSlot(var1.asOutput(), MG).get() : null; - Variable mom0 = - momentum > 0.F ? instance.getSlot(var0.asOutput(), MOMENTUM).get() : null; - Variable mom1 = - momentum > 0.F ? instance.getSlot(var1.asOutput(), MOMENTUM).get() : null; - Variable rms0 = instance.getSlot(var0.asOutput(), RMS).get(); - Variable rms1 = instance.getSlot(var1.asOutput(), RMS).get(); - - float[] zeros = {0.0F, 0.0F}; - float[] ones = {1.0F, 1.0F}; // temp to match RMSProp - FloatNdArray mg0_np = NdArrays.vectorOf(zeros); - FloatNdArray mg1_np = NdArrays.vectorOf(zeros); - FloatNdArray rms0_np = NdArrays.vectorOf(ones); - FloatNdArray rms1_np = NdArrays.vectorOf(ones); - FloatNdArray mom0_np = NdArrays.vectorOf(zeros); - FloatNdArray mom1_np = NdArrays.vectorOf(zeros); - - for (int i = 0; i < numSteps; i++) { - session.run(update, instance.getFeedDict()); - FloatNdArray[] result0 = - calc( - var0_np, - grads0_np, - mg0_np, - rms0_np, - mom0_np, - learningRate, - decay, - momentum, - epsilon, - centered); - var0_np = result0[VAR_T]; - mg0_np = result0[MG_T]; - rms0_np = result0[RMS_T]; - mom0_np = result0[MOM_T]; - - FloatNdArray[] result1 = - calc( - var1_np, - grads1_np, - mg1_np, - rms1_np, - mom1_np, - learningRate, - decay, - momentum, - epsilon, - centered); - - var1_np = result1[VAR_T]; - mg1_np = result1[MG_T]; - rms1_np = result1[RMS_T]; - mom1_np = result1[MOM_T]; - - if (centered) { - session.evaluate(mg0_np, mg0); - session.evaluate(mg0_np, mg0); - } - if (momentum > 0.F) { - session.evaluate(mom0_np, mom0); - session.evaluate(mom1_np, mom1); - } - - /* TODO the values returned from rms slot, do not match what I see in the python test */ - session.evaluate(rms0_np, rms0); - session.evaluate(rms1_np, rms1); - - session.evaluate(var0_np, var0); - session.evaluate(var1_np, var1); - - learningRate *= 0.9F; - instance.setLearningRate(learningRate); - } - } - } - } - - FloatNdArray[] calc( - FloatNdArray var_np, - FloatNdArray grad_np, - FloatNdArray mg_np, - FloatNdArray rms_np, - FloatNdArray mom, - float lr, - float decay, - float momentum, - float epsilon, - boolean centered) { - - FloatNdArray[] result = new FloatNdArray[4]; // var_t, mg_t, rms_t, mom_t - result[RMS_T] = calcRMS(rms_np, grad_np, decay); // RMS - - FloatNdArray denom_t; - if (centered) { - result[MG_T] = calcMG(mg_np, grad_np, decay); - // rms_t - mg_t * mg_t - denom_t = ND.sub(result[RMS_T], ND.square(result[MG_T])); - } else { - result[MG_T] = mg_np; - denom_t = rms_np; - } - if (momentum > 0.F) { - // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) - result[MOM_T] = calcMom(momentum, mom, lr, grad_np, denom_t, epsilon); - // var_t = var - mom_t - result[VAR_T] = ND.sub(var_np, result[MOM_T]); - } else { - result[MOM_T] = mom; - result[VAR_T] = calcVar(var_np, grad_np, lr, denom_t, epsilon); - } - - return result; - } - - private FloatNdArray calcRMS(FloatNdArray rms_np, FloatNdArray grad_np, float decay) { - // rms * rho + (1 - rho) * g * g - FloatNdArray rms_rho = ND.mul(rms_np, decay); - FloatNdArray squareG = ND.square(grad_np); - float oneRHO = 1.0F - decay; - FloatNdArray decayG2 = ND.mul(oneRHO, squareG); - FloatNdArray result = 
ND.add(rms_rho, decayG2); - return result; - } - - private FloatNdArray calcMG(FloatNdArray mg_np, FloatNdArray grad_np, float decay) { - // mg_t = mg * rho + (1 - rho) * g - FloatNdArray mg_rho = ND.mul(mg_np, decay); - float oneRHO = 1.0F - decay; - FloatNdArray decayG = ND.mul(oneRHO, grad_np); - FloatNdArray result = ND.add(mg_rho, decayG); - return result; - } - - private FloatNdArray calcMom( - float momentum, - FloatNdArray mom, - float lr, - FloatNdArray grad_np, - FloatNdArray denom_t, - float epsilon) { - // momentum * mom + lr * g / (np.sqrt(denom_t + epsilon)) - FloatNdArray moMo = ND.mul(momentum, mom); - FloatNdArray dividend = ND.mul(lr, grad_np); - FloatNdArray divisor = ND.sqrt(ND.add(denom_t, epsilon)); - FloatNdArray quotient = ND.div(dividend, divisor); - FloatNdArray result = ND.add(moMo, quotient); - return result; - } - - private FloatNdArray calcVar( - FloatNdArray var_np, FloatNdArray grad_np, float lr, FloatNdArray denom_t, float epsilon) { - // var - lr * g / (np.sqrt(denom_t) + epsilon) - FloatNdArray dividend = ND.mul(lr, grad_np); - FloatNdArray divisor = ND.add(ND.sqrt(denom_t), epsilon); - FloatNdArray quotient = ND.div(dividend, divisor); - FloatNdArray result = ND.sub(var_np, quotient); - return result; - } -} From dddc2975922a8c3322b04a4cc991b06ab5d254a8 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 14 Sep 2020 19:52:43 -0400 Subject: [PATCH 06/14] Reformatted code --- .../java/org/tensorflow/framework/optimizers/AdaDelta.java | 4 ---- .../java/org/tensorflow/framework/optimizers/AdaGrad.java | 5 +---- .../main/java/org/tensorflow/framework/optimizers/Ftrl.java | 4 +--- 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 9453dce7343..be5b39a534d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -68,14 +68,10 @@ public class AdaDelta extends Optimizer { public static final float RHO_DEFAULT = 0.95f; public static final float EPSILON_DEFAULT = 1e-7f; - - private final float rho; private final float epsilon; - - public AdaDelta(Graph graph) { this(graph, LEARNING_RATE_DEFAULT, RHO_DEFAULT, EPSILON_DEFAULT); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 3aae6f71693..384c04e60bb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -139,10 +139,7 @@ private void createAdaGradSlot(Output v) { protected Op applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, ACCUMULATOR).get(); return tf.train.applyAdagrad( - variable, - slot, - tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), - gradient); + variable, slot, tf.dtypes.cast(getLearningRateOperand(), gradient.dataType()), gradient); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 8e7b638dc21..b455ae1f0be 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -176,7 +176,6 @@ public Ftrl( this.l2RegularizationStrength = l2Strength; this.l2ShrinkageRegularizationStrength = l2ShrinkageRegularizationStrength; validateParams(); - } /** Validates all the settings of the Frtl Optmizer */ @@ -248,8 +247,7 @@ protected Op applyDense(Output gradient, Output variable tf.dtypes.cast(this.getLearningRateOperand(), gradient.dataType()), tf.dtypes.cast(tf.constant(l1RegularizationStrength), gradient.dataType()), tf.dtypes.cast(tf.constant(l2RegularizationStrength), gradient.dataType()), - tf.dtypes.cast( - tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), + tf.dtypes.cast(tf.constant(l2ShrinkageRegularizationStrength), gradient.dataType()), tf.dtypes.cast(tf.constant(learningRatePower), gradient.dataType()), ApplyFtrl.useLocking(true)); } From cb8104ce7f26456bb3acb3c40b7f4d67fd5d8b90 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 14 Sep 2020 19:54:37 -0400 Subject: [PATCH 07/14] Reformatted code --- .../framework/optimizers/AdaDeltaTest.java | 3 +- .../framework/optimizers/AdaGradDATest.java | 14 ++-- .../framework/optimizers/AdamTest.java | 16 +--- .../framework/optimizers/AdamaxTest.java | 3 +- .../optimizers/GradientDescentTest.java | 5 +- .../framework/optimizers/MomentumTest.java | 3 +- .../framework/optimizers/NadamTest.java | 79 +++++++++---------- .../framework/optimizers/RMSPropTest.java | 3 +- 8 files changed, 58 insertions(+), 68 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 37c7bc5ded0..3547ea9a30e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -235,7 +235,8 @@ public void testWithLearningRateDecay() { float totUpdate = 0; for (int step = 0; step < numSteps; step++) { assertEquals(learningRate, instance.getLearningRate(), epsilon); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(adadeltaUpdate, instance.getFeedMap()); accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); updates[step] = diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java index 8e44f7db0ed..1f8044c1168 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import static org.junit.jupiter.api.Assertions.assertEquals; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -29,6 +28,8 @@ import java.util.ArrayList; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; + /** Test cases for AdaGradDA Optimizer */ public class AdaGradDATest { @@ -140,14 +141,15 @@ public void testWithLearningRateDecay() { session.evaluate(var1Init, var1); float[][][] expected = { - {{ -2.121320f, -2.683281f},{ -0.298511f, 
-0.588348f}}, - {{ -3.680166f, -4.483282f}, { -0.565851f, -1.107964f}}, - {{ -4.895166f, -5.831203f}, { -0.805286f, -1.567190f}}, - {{ -5.873222f, -6.892054f}, { -1.019739f, -1.973306f}} + {{-2.121320f, -2.683281f}, {-0.298511f, -0.588348f}}, + {{-3.680166f, -4.483282f}, {-0.565851f, -1.107964f}}, + {{-4.895166f, -5.831203f}, {-0.805286f, -1.567190f}}, + {{-5.873222f, -6.892054f}, {-1.019739f, -1.973306f}} }; for (int i = 0; i < numSteps; i++) { assertEquals(learningRate, instance.getLearningRate(), epsilon); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); session.evaluate(expected[i][0], var0); session.evaluate(expected[i][1], var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index c6681fa1557..a8be65c3650 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -274,12 +274,7 @@ public void testWithLearningRateDecay() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> assertEquals(powers[0], f.getFloat(), epsilon1) - ); + result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); } try (Tensor result = session @@ -289,14 +284,11 @@ public void testWithLearningRateDecay() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); } assertEquals(learningRate, instance.getLearningRate(), 1e-6f); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); float lr_t = diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index 1d9648b9b51..57d3cbdb70c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -269,7 +269,8 @@ public void testWithLearningRateDecay() { }); } assertEquals(learningRate, instance.getLearningRate(), epsilon); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0, learningRate); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index ec59a046421..8e793e35d5f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ 
-16,7 +16,6 @@ import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Momentum.MOMENTUM; /** Test cases for GradientDescent Optimizer */ public class GradientDescentTest { @@ -129,7 +128,6 @@ public void testWithLearningRateDecay() { Op update = instance.applyGradients(gradsAndVars, "GradientDescentTest"); - /* initialize the local variables */ session.run(var0Initializer); session.run(var1Initializer); @@ -157,7 +155,8 @@ public void testWithLearningRateDecay() { }; for (int step = 0; step < numSteps; step++) { assertEquals(learningRate, instance.getLearningRate(), 1e-6f); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); session.evaluate(expectedVar0[step], var0); session.evaluate(expectedVar1[step], var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index ce5ad379629..b54e3b52a26 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -256,7 +256,8 @@ public void testWithLearningRateDecay() { }; for (int step = 0; step < numSteps; step++) { assertEquals(learningRate, instance.getLearningRate(), 1e-6); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); session.evaluate(expectedVar0[step], var0); session.evaluate(expectedVar1[step], var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index fcdd1e3ef7c..c7c17689a33 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -243,7 +243,6 @@ public void testWithLearningRateDecay() { Constant grads0 = tf.constant(grads0Init); Constant grads1 = tf.constant(grads1Init); - /* build the GradsAnvVars */ List> gradsAndVars = new ArrayList<>(); gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); @@ -282,46 +281,40 @@ public void testWithLearningRateDecay() { session.evaluate(var1Init, var1); try (Tensor result = - session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> assertEquals(1F, f.getFloat(), epsilon1)); + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; for (int step = 0; step < numSteps; step++) { assertEquals(learningRate, instance.getLearningRate(), 1e-6f); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, 
instance.getFeedMap()); float mut = - Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); momentum = momentum * mut; try (Tensor result = - session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0) - .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> assertEquals(momentum, f.getFloat(), epsilon1)); + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); - FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache, learningRate); + FloatNdArray[] resultsNP = + nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache, learningRate); var0Np = resultsNP[VAR]; m0 = resultsNP[M]; v0 = resultsNP[V]; @@ -349,24 +342,24 @@ public void testWithLearningRateDecay() { } } - private FloatNdArray[] nadamUpdateNdArray( - FloatNdArray varNp, - FloatNdArray gradsNp, - int t, - FloatNdArray m, - FloatNdArray v, - FloatNdArray mCache) { + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache) { return nadamUpdateNdArray(varNp, gradsNp, t, m, v, mCache, 0.001F); } + private FloatNdArray[] nadamUpdateNdArray( - FloatNdArray varNp, - FloatNdArray gradsNp, - int t, - FloatNdArray m, - FloatNdArray v, - FloatNdArray mCache, - float alpha) { + FloatNdArray varNp, + FloatNdArray gradsNp, + int t, + FloatNdArray m, + FloatNdArray v, + FloatNdArray mCache, + float alpha) { float beta1 = 0.9F; float beta2 = 0.999F; @@ -382,7 +375,7 @@ private FloatNdArray[] nadamUpdateNdArray( FloatNdArray vPrimeT = ND.div(vT, 1.F - (float) Math.pow(beta2, t + 1)); FloatNdArray mBarT = ND.add(ND.mul((1 - muT), gPrimeT), ND.mul(muT1, mPrimeT)); FloatNdArray paramT = - ND.sub(varNp, ND.div(ND.mul(alpha, mBarT), ND.add(ND.sqrt(vPrimeT), epsilon))); + ND.sub(varNp, ND.div(ND.mul(alpha, mBarT), ND.add(ND.sqrt(vPrimeT), epsilon))); FloatNdArray[] results = new FloatNdArray[3]; results[VAR] = paramT; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java index 6d489951c77..2a012ff0f99 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -306,7 +306,8 @@ public void testWithLearningRateDecay() { for (int i = 0; i < numSteps; i++) { assertEquals(learningRate, instance.getLearningRate(), epsilon); - session.evaluate(learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); + session.evaluate( + learningRate, tf.identity(instance.getLearningRateOperand()), instance.getFeedMap()); session.run(update, instance.getFeedMap()); FloatNdArray[] result0 = calc( From 15189b4c90e18eb859c6e708615bc741594eccad Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 14 Sep 2020 19:55:07 -0400 Subject: [PATCH 08/14] Reformatted code --- .../framework/utils/TestSession.java | 170 +++++++++--------- 1 file changed, 83 insertions(+), 87 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index 
47c39e820fc..713225a4962 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -65,6 +65,7 @@ public void initialize() { /** * Returns the Graph if in Graph mode, or null if in EagerMode + * * @return the Graph if in Graph mode, or null if in EagerMode */ public Graph getGraph() { @@ -78,19 +79,18 @@ public Graph getGraph() { * * @param op The Operation to run */ - public void run(Op op) { - run(op, null); + public void run(Op op) { + run(op, null); } - /** * Perform session.run() * *

    If in eager mode, this does nothing. * * @param op The Operation to run - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public abstract void run(Op op, Map, Tensor> feedMap); @@ -110,8 +110,8 @@ public void evaluate(Number expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -136,8 +136,8 @@ public void evaluate(Number expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Number expected, Op input, Map, Tensor> feedMap) { @@ -161,14 +161,12 @@ public void evaluate(Number[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the input */ public void evaluate( - Number[] expected, - Op input, - Map, Tensor> feedMap) { + Number[] expected, Op input, Map, Tensor> feedMap) { Output output = input.op().output(0); evaluate(expected, output, feedMap); } @@ -190,8 +188,8 @@ public void evaluate(Number[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -218,8 +216,8 @@ public void evaluate(byte expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -245,8 +243,8 @@ public void evaluate(int expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -272,8 +270,8 @@ public void evaluate(long expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. 
+ * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -299,8 +297,8 @@ public void evaluate(float expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -326,8 +324,8 @@ public void evaluate(double expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public abstract void evaluate( @@ -351,8 +349,8 @@ public void evaluate(byte[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -382,8 +380,8 @@ public void evaluate(int[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -413,8 +411,8 @@ public void evaluate(long[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -444,8 +442,8 @@ public void evaluate(float[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -475,8 +473,8 @@ public void evaluate(double[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. 
* @param the data type of the input */ public void evaluate( @@ -506,8 +504,8 @@ public void evaluate(Number[] expected, Output input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public abstract void evaluate( @@ -530,8 +528,8 @@ public void evaluate(String expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( String expected, @@ -555,8 +553,8 @@ public void evaluate(String expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( String expected, Op input, Map, Tensor> feedMap) { @@ -578,13 +576,11 @@ public void evaluate(String[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( - String[] expected, - Op input, - Map, Tensor> feedMap) { + String[] expected, Op input, Map, Tensor> feedMap) { Output output = input.op().output(0); evaluate(expected, output, feedMap); } @@ -605,8 +601,8 @@ public void evaluate(String[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public abstract void evaluate( String[] expected, @@ -628,8 +624,8 @@ public void evaluate(Boolean expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Boolean expected, @@ -653,8 +649,8 @@ public void evaluate(Boolean expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. 
*/ public void evaluate( Boolean expected, Op input, Map, Tensor> feedMap) { @@ -677,8 +673,8 @@ public void evaluate(Boolean[] expected, Op input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Boolean[] expected, @@ -704,8 +700,8 @@ public void evaluate(Boolean[] expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void evaluate( Boolean[] expected, @@ -730,8 +726,8 @@ public void evaluate(Boolean[] expected, Output input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public abstract void evaluate( Boolean[] expected, @@ -765,8 +761,8 @@ public void evaluate(Operand expected, Operand input) { * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the feedMap entries */ public abstract void evaluate( @@ -790,8 +786,8 @@ public void evaluate(FloatNdArray expected, Operand input * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public void evaluate( @@ -817,8 +813,8 @@ public void evaluate(FloatNdArray expected, Output input) * * @param expected the expected value * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public abstract void evaluate( @@ -844,8 +840,8 @@ public void evaluate(Operand input, Predicate pre * @param input the actual value * @param predicate a predicate that accepts a Number as an argument, if the result of the * predicate is false, then the test will fail - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type of the input */ public abstract void evaluate( @@ -878,8 +874,8 @@ public void print(Operand input) { * Print the results to the "standard" output stream. 
* * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the feedMap entries */ public void print( @@ -900,8 +896,8 @@ public void print(Op input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void print(Op input, Map, Tensor> feedMap) { print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0), feedMap); @@ -921,8 +917,8 @@ public void print(Output input) { * Print the results to the "standard" output stream. * * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the input */ public void print( @@ -946,8 +942,8 @@ public void print(OutputStream out, Operand input) { * * @param out the output stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the feedMap entries */ public void print( @@ -972,8 +968,8 @@ public void print(OutputStream out, Op input) { * * @param out the output stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. */ public void print( OutputStream out, Op input, Map, Tensor> feedMap) { @@ -996,8 +992,8 @@ public void print(OutputStream out, Output input) { * * @param out the output stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the input */ public void print( @@ -1023,8 +1019,8 @@ public void print(Writer writer, Operand input) { * * @param writer the character stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the input */ public void print( @@ -1049,8 +1045,8 @@ public void print(Writer writer, Op input) { * * @param writer the character stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. 
*/ public void print( Writer writer, Op input, Map, Tensor> feedMap) { @@ -1073,8 +1069,8 @@ public void print(Writer writer, Output input) { * * @param writer the character stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the input */ public void print( @@ -1100,8 +1096,8 @@ public void print(PrintWriter writer, Output input) { * * @param writer the character stream * @param input the actual value - * @param feedMap The dictionary of values to pass to the feed() operation of the runner, - * required for placeholders. + * @param feedMap The dictionary of values to pass to the feed() operation of the runner, required + * for placeholders. * @param the data type for the input */ public abstract void print( From eb2c48e0695b03fc8a71b880f5029c0ec8abb44c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 15 Sep 2020 08:13:23 -0400 Subject: [PATCH 09/14] Remove premature commit --- .../schedules/PiecewiseConstantDecay.java | 58 -------- .../optimizers/schedules/PolynomialDecay.java | 127 ------------------ .../schedules/PiecewiseConstantDecayTest.java | 16 --- .../schedules/PolynomialDecayTest.java | 24 ---- 4 files changed, 225 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java deleted file mode 100644 index 43f85fa0ff1..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecay.java +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.optimizers.schedules;; - -/** - * A LearningRateSchedule that uses a piecewise constant decay schedule. - *

- *
- * The function computes the piecewise constant
- * when passed the current optimizer step. This can be useful for changing the
- * learning rate value across different invocations of optimizer functions.
- *

    Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 - for the next 10000 steps, and 0.1 for any additional steps. - */ -public class PiecewiseConstantDecay implements LearningRateSchedule { - private float[] boundaries; - private float[] values; - - private int lastIndex = 0; - - /** - * Create an PiecewiseConstantDecay - * - * @param boundaries An array of with strictly increasing entries - * @param values An array that specifies the - values for the intervals defined by boundaries. It should have one - more element than boundaries. - * @throws java.lang.IllegalArgumentException if the the length of values does not have 1 more element than boundaries. - */ - public PiecewiseConstantDecay(float[] boundaries, float[] values) { - if(boundaries.length != values.length - 1) { - throw new IllegalArgumentException("The length of boundaries should be 1 less than the length of values"); - } - this.boundaries = boundaries; - this.values = values; - } - - - @Override - public float call(int step) { - if(lastIndex < boundaries.length && step > boundaries[lastIndex]) - lastIndex++; - return values[lastIndex]; - } - -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java deleted file mode 100644 index 0988577c38f..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecay.java +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.optimizers.schedules; - -/** - * A LearningRateSchedule that uses a polynomial decay schedule. - * - *

- *
- * It is commonly observed that a monotonically decreasing learning rate, whose degree of change
- * is carefully chosen, results in a better performing model. This schedule applies a polynomial
- * decay function to an optimizer step, given a provided `initial_learning_rate`, to reach an
- * `end_learning_rate` in the given `decay_steps`.
- *
- * The schedule is a 1-arg callable that produces a decayed learning rate when passed the current
- * optimizer step. This can be useful for changing the learning rate value across different
- * invocations of optimizer functions. It is computed as:
- *
- *     step = min(step, decay_steps)
- *     ((initialLearningRate - endLearningRate) *
- *         (1 - step / decaySteps) ^ (power)
- *     ) + endLearningRate
- *
- *

    If `cycle` is True then a multiple of `decay_steps` is used, the first one that is bigger than - * `step`. - */ -public class PolynomialDecay implements LearningRateSchedule { - private static final float END_LEARNING_RATE_DEFAULT = 0.0001f; - public static final float POWER_DEFAULT = 1.0f; - public static final boolean CYCLE_DEFAULT = false; - - protected final float initialLearningRate; - protected final float decaySteps; - protected final float endLearningRate; - protected final float power; - protected final boolean cycle; - - /** - * Create a PolynomialDecay - * - * @param initialLearningRate The initial learning rate. - * @param decaySteps How often to apply decay. - */ - public PolynomialDecay(float initialLearningRate, int decaySteps) { - this(initialLearningRate, decaySteps, END_LEARNING_RATE_DEFAULT, POWER_DEFAULT, CYCLE_DEFAULT); - } - - /** - * Create a PolynomialDecay - * - * @param initialLearningRate The initial learning rate. - * @param decaySteps How often to apply decay. - * @param cycle Whether or not it should cycle beyond decay_steps. Default is false. - */ - public PolynomialDecay(float initialLearningRate, int decaySteps, boolean cycle) { - this(initialLearningRate, decaySteps, END_LEARNING_RATE_DEFAULT, POWER_DEFAULT, cycle); - } - - /** - * Create a PolynomialDecay - * - * @param initialLearningRate The initial learning rate. - * @param decaySteps How often to apply decay. - * @param endLearningRate The end learning rate. Default is 0.0001. - */ - public PolynomialDecay(float initialLearningRate, int decaySteps, float endLearningRate) { - this(initialLearningRate, decaySteps, endLearningRate, POWER_DEFAULT, CYCLE_DEFAULT); - } - - /** - * Create a PolynomialDecay - * - * @param initialLearningRate The initial learning rate. - * @param decaySteps How often to apply decay. - * @param endLearningRate The end learning rate. Default is 0.0001. - * @param power The power of the polynomial. Defaults to linear, 1.0. - * @param cycle Whether or not it should cycle beyond decay_steps. Default is false. - */ - public PolynomialDecay( - float initialLearningRate, - int decaySteps, - float endLearningRate, - float power, - boolean cycle) { - this.initialLearningRate = initialLearningRate; - this.decaySteps = decaySteps; - this.endLearningRate = endLearningRate; - this.power = power; - this.cycle = cycle; - } - - @Override - public float call(int step) { - - float lDecaySteps = decaySteps; - float lStep = step; - if (cycle) { - float multipler = step == 0 ? 
1.0f : (float) Math.ceil(step / decaySteps); - lDecaySteps = decaySteps * multipler; - } else { - lStep = Math.min(lStep, lDecaySteps); - } - - float p = lStep / lDecaySteps; - - float f = (this.initialLearningRate - this.endLearningRate) * (float) Math.pow(1.0f - p, power); - return f + endLearningRate; - } -} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java deleted file mode 100644 index dac8caa19a3..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PiecewiseConstantDecayTest.java +++ /dev/null @@ -1,16 +0,0 @@ -package org.tensorflow.framework.optimizers.schedules; - -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.*; - -class PiecewiseConstantDecayTest { - - public PiecewiseConstantDecayTest() {} - - @Test - public void testDecay() { - - } - -} \ No newline at end of file diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java deleted file mode 100644 index a28e56ad7cb..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/schedules/PolynomialDecayTest.java +++ /dev/null @@ -1,24 +0,0 @@ -package org.tensorflow.framework.optimizers.schedules; - -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.*; - -class PolynomialDecayTest { - - public PolynomialDecayTest() {} - - @Test - public void testBeginWithCycle() { - float initialLearningRate = 0.1f; - int decaySteps = 10; - float decayRate = 0.96f; - float epsilon = 1e-6f; - PolynomialDecay instance = new PolynomialDecay(initialLearningRate, decaySteps, true); - float expected = initialLearningRate; - float actual = instance.call(0); - assertEquals(expected, actual, epsilon); - - } - -} \ No newline at end of file From d5edd353059839bcb207b5a88d1c53eb31cc8e3c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 19 Sep 2020 18:45:40 -0400 Subject: [PATCH 10/14] Added JavaDoc back in, changed setLearningRate() to setLearningRate(newLearningRate), eliminated spurious "this." 
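
A minimal usage sketch (not part of this patch) of how the placeholder-backed learning rate is driven between steps. The names `session`, `trainOp` and `totalSteps` are illustrative; only setLearningRate(float), getLearningRate() and getFeedMap() come from this series, and feeding the map assumes the Session.Runner#feed(Operand, Tensor) overload that the GraphTestSession changes rely on.

    import org.tensorflow.Session;
    import org.tensorflow.framework.optimizers.Optimizer;
    import org.tensorflow.op.Op;

    class LearningRateFeedSketch {
      static void train(Session session, Optimizer optimizer, Op trainOp, int totalSteps) {
        for (int step = 0; step < totalSteps; step++) {
          if (step == totalSteps / 2) {
            // halve the rate mid-run; this only swaps the fed Tensor, the graph is unchanged
            optimizer.setLearningRate(optimizer.getLearningRate() * 0.5f);
          }
          Session.Runner runner = session.runner().addTarget(trainOp);
          // pair the learning-rate Placeholder with its current Tensor value on every step
          optimizer.getFeedMap().forEach(runner::feed);
          runner.run();
        }
      }
    }
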
--- .../framework/optimizers/Optimizer.java | 71 +++++++++++++++---- 1 file changed, 57 insertions(+), 14 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 8e0471dc0ba..7c5258348c2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -111,20 +111,47 @@ protected Optimizer(Graph graph, String name, float learningRate) { setLearningRate(learningRate); } + /** + * Creates a name by combining a variable name and a slot name + * + * @param variable the variable + * @param slotName the name of the slot + * @return the combined name + */ public static String createName(Output variable, String slotName) { return variable.op().name() + "-" + slotName; } + /** + * Minimizes the loss by updating the variables + * + * @param loss the loss operation that returns the value to minimize + * @return returns op that minimizes the loss by updating the listed variables + */ public Op minimize(Operand loss) { return minimize(loss, getOptimizerName() + "-minimize"); } + /** + * Minimizes the loss by updating the variables + * + * @param loss the loss operation that returns the value to minimize + * @param name the name for the minimize operation + * @return op that minimizes the loss by updating the listed variables + */ public Op minimize(Operand loss, String name) { List> gradsAndVars = computeGradients(loss); return applyGradients(gradsAndVars, name); } + /** + * Computes the gradients based on a loss operand. + * + * @param loss the loss operation + * @param the data type of the loss, gradients and variables. + * @return the computed gradients + */ public List> computeGradients(Operand loss) { List variables = new ArrayList<>(); graph @@ -156,6 +183,13 @@ public List> computeGradients(Operand loss) { return gradVarPairs; } + /** + * Applies gradients to variables + * + * @param gradsAndVars the list of (gradient, variable) pairs. + * @param name the name of the apply gradients operation + * @return an Op that applies the gradients to the variables. + */ public Op applyGradients(List> gradsAndVars, String name) { List> variables = gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList()); @@ -242,6 +276,13 @@ protected Optional prepare(String scopeName) { */ protected void createSlots(List> variables) {} + /** + * Generates the gradient update operations for the specific variable and gradient. + * + * @param gradVarPair the list of (gradient, variable) pairs. + * @param the datatype of the gradients and variables. + * @return An operand which applies the desired optimizer update to the variable. 
+ */ private Op applyDense(GradAndVar gradVarPair) { return applyDense(gradVarPair.getGradient(), gradVarPair.getVariable()); } @@ -280,20 +321,20 @@ protected Op finish(List updateOperations, String name) { /** * Sets the learning rate * - * @param learningRate the learning rate + * @param newLearningRate the new earning rate */ - public final void setLearningRate(float learningRate) { - if (this.learningRatePlaceholder == null) { - this.learningRatePlaceholder = + public final void setLearningRate(float newLearningRate) { + if (learningRatePlaceholder == null) { + learningRatePlaceholder = tf.withSubScope(LEARNING_RATE) .placeholder(TFloat32.DTYPE, Placeholder.shape(Shape.scalar())); } - if (this.learningRate != learningRate) { - if (this.learningRateTensor != null) this.learningRateTensor.close(); - this.learningRate = learningRate; - this.learningRateTensor = TFloat32.scalarOf(this.learningRate); - this.feedMap = Collections.singletonMap(this.learningRatePlaceholder, learningRateTensor); + if (learningRate != newLearningRate) { + if (learningRateTensor != null) learningRateTensor.close(); + learningRate = newLearningRate; + learningRateTensor = TFloat32.scalarOf(learningRate); + feedMap = Collections.singletonMap(learningRatePlaceholder, learningRateTensor); } } @@ -303,7 +344,7 @@ public final void setLearningRate(float learningRate) { * @return the learning rate */ public float getLearningRate() { - return this.learningRate; + return learningRate; } /** @@ -312,7 +353,7 @@ public float getLearningRate() { * @return the learning rate Operand */ protected Operand getLearningRateOperand() { - return this.learningRatePlaceholder; + return learningRatePlaceholder; } /** @@ -323,13 +364,15 @@ protected Operand getLearningRateOperand() { * Operand has been set. */ public Map, Tensor> getFeedMap() { - return this.feedMap; + return feedMap; } + /** {@inheritDoc} */ public void close() { // close the learningRate Tensor if it exists. - if (this.feedMap != null) { - this.feedMap.get(this.learningRatePlaceholder).close(); + if (learningRateTensor != null) { + learningRateTensor.close(); + learningRateTensor = null; } } From 2f57c1df6c3eaeec39a64b87da8cf5a66525904a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 21 Sep 2020 15:00:45 -0400 Subject: [PATCH 11/14] Change Optimizer to only have one constructor, "protected Optimizer(Graph graph, String name, float learningRate)"", change all the subclass ctors to use this one. 
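
A condensed illustration of the pattern this commit applies (hypothetical subclass, not from the patch): every convenience constructor chains to the single protected Optimizer(Graph, String, float) constructor, and a null name falls back to getOptimizerName() inside the base class.

    import org.tensorflow.Graph;
    import org.tensorflow.framework.optimizers.Optimizer;

    // declared abstract so the gradient-update methods can be omitted from the sketch
    public abstract class ExampleOptimizer extends Optimizer {
      private final float rho;

      public ExampleOptimizer(Graph graph, float learningRate, float rho) {
        this(graph, null, learningRate, rho);   // delegate, no duplicated field setup
      }

      public ExampleOptimizer(Graph graph, String name, float learningRate, float rho) {
        super(graph, name, learningRate);       // the one remaining base constructor
        this.rho = rho;
      }

      @Override
      public String getOptimizerName() {
        return "ExampleOptimizer";              // used when name is null
      }
    }
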
--- .../framework/optimizers/AdaDelta.java | 4 +- .../framework/optimizers/AdaGrad.java | 9 +--- .../framework/optimizers/AdaGradDA.java | 18 +------ .../tensorflow/framework/optimizers/Adam.java | 5 +- .../framework/optimizers/Adamax.java | 5 +- .../tensorflow/framework/optimizers/Ftrl.java | 16 +++--- .../framework/optimizers/GradientDescent.java | 2 +- .../framework/optimizers/Momentum.java | 4 +- .../framework/optimizers/Nadam.java | 5 +- .../framework/optimizers/Optimizer.java | 49 +------------------ .../framework/optimizers/RMSProp.java | 6 +-- 11 files changed, 20 insertions(+), 103 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index be5b39a534d..9f2e868ea1a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -95,9 +95,7 @@ public AdaDelta(Graph graph, float learningRate) { * @param epsilon A constant epsilon used to better conditioning the grad update */ public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { - super(graph, learningRate); - this.rho = rho; - this.epsilon = epsilon; + this(graph, null, learningRate, rho, epsilon); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 384c04e60bb..9a9498630a8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -74,14 +74,7 @@ public AdaGrad(Graph graph, float learningRate) { * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative */ public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { - super(graph, learningRate); - if (initialAccumulatorValue < 0F) { - throw new IllegalArgumentException( - String.format( - "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue)); - } - this.learningRate = learningRate; - this.initialAccumulatorValue = initialAccumulatorValue; + this(graph, null, learningRate, initialAccumulatorValue); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index 48823bf5fd8..af5c3737251 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -100,23 +100,7 @@ public AdaGradDA( float initialAccumulatorValue, float l1Strength, float l2Strength) { - super(graph, learningRate); - if (initialAccumulatorValue <= 0F) { - throw new IllegalArgumentException( - String.format( - "initialAccumulatorValue must be greater than zero: %f", initialAccumulatorValue)); - } - if (l1Strength < 0F) { - throw new IllegalArgumentException( - String.format("l1Strength must not be negative: %f", l1Strength)); - } - if (l2Strength < 0F) { - throw new IllegalArgumentException( - String.format("l2Strength must not be negative: %f", l2Strength)); - } - this.initialAccumulatorValue = initialAccumulatorValue; - this.l1Strength = l1Strength; - this.l2Strength = l2Strength; + this(graph, null, learningRate, 
initialAccumulatorValue, l1Strength, l2Strength); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index b5915b4bd57..3ca9fbdab57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -99,10 +99,7 @@ public Adam(Graph graph, float learningRate) { * 1 of the paper. Defaults to 1e-8. */ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, learningRate); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index e568f881773..c381013e97c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -92,10 +92,7 @@ public Adamax(Graph graph, String name, float learningRate) { * @param epsilon A small constant for numerical stability. */ public Adamax(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, learningRate); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index b455ae1f0be..edbe91c62e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -132,13 +132,15 @@ public Ftrl( float l1Strength, float l2Strength, float l2ShrinkageRegularizationStrength) { - super(graph, learningRate); - this.learningRatePower = learningRatePower; - this.initialAccumulatorValue = initialAccumulatorValue; - this.l1RegularizationStrength = l1Strength; - this.l2RegularizationStrength = l2Strength; - this.l2ShrinkageRegularizationStrength = l2ShrinkageRegularizationStrength; - validateParams(); + this( + graph, + null, + learningRate, + learningRatePower, + initialAccumulatorValue, + l1Strength, + l2Strength, + l2ShrinkageRegularizationStrength); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index 36f36057c26..f57503d3347 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -44,7 +44,7 @@ public GradientDescent(Graph graph) { * @param learningRate the learning rate, defaults to 0.01 */ public GradientDescent(Graph graph, float learningRate) { - super(graph, learningRate); + super(graph, null, learningRate); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index a099eae53e8..19e3f275f1f 100644 --- 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -84,9 +84,7 @@ public Momentum(Graph graph, float learningRate, float momentum) { * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. */ public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { - super(graph, learningRate); - this.momentum = momentum; - this.useNesterov = useNesterov; + this(graph, null, learningRate, momentum, useNesterov); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index d0228eb8b3a..ece7c024969 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -94,10 +94,7 @@ public Nadam(Graph graph, float learningRate) { * @param epsilon A small constant for numerical stability. Default is 1e-8. */ public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { - super(graph, learningRate); - this.betaOne = betaOne; - this.betaTwo = betaTwo; - this.epsilon = epsilon; + this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 7c5258348c2..868d04672f1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -51,51 +51,6 @@ public abstract class Optimizer implements AutoCloseable { private Tensor learningRateTensor; private Map, Tensor> feedMap = null; - /** - * Builds an optimizer for the supplied graph. - * - *

    Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. - * - * @param graph The graph to optimize. - */ - protected Optimizer(Graph graph) { - this.graph = graph; - this.tf = Ops.create(graph).withName(getOptimizerName()); - this.slots = new HashMap<>(); - this.globals = new ArrayList<>(); - setLearningRate(LEARNING_RATE_DEFAULT); - } - - /** - * Builds an optimizer for the supplied graph. - * - *

    Uses the name from {@link Optimizer#getOptimizerName()} to name the operations. - * - * @param graph The graph to optimize. - * @param learningRate the learning rate. - */ - protected Optimizer(Graph graph, float learningRate) { - this.graph = graph; - this.tf = Ops.create(graph).withName(getOptimizerName()); - this.slots = new HashMap<>(); - this.globals = new ArrayList<>(); - setLearningRate(learningRate); - } - - /** - * Builds an optimizer for the supplied graph. - * - * @param graph The graph to optimize. - * @param name The base name for the operations. - */ - protected Optimizer(Graph graph, String name) { - this.graph = graph; - this.tf = Ops.create(graph).withName(name); - this.slots = new HashMap<>(); - this.globals = new ArrayList<>(); - setLearningRate(LEARNING_RATE_DEFAULT); - } - /** * Builds an optimizer for the supplied graph. * @@ -105,7 +60,7 @@ protected Optimizer(Graph graph, String name) { */ protected Optimizer(Graph graph, String name, float learningRate) { this.graph = graph; - this.tf = Ops.create(graph).withName(name); + this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name); this.slots = new HashMap<>(); this.globals = new ArrayList<>(); setLearningRate(learningRate); @@ -367,7 +322,7 @@ public Map, Tensor> getFeedMap() { return feedMap; } - /** {@inheritDoc} */ + /** {@inheritDoc} */ public void close() { // close the learningRate Tensor if it exists. if (learningRateTensor != null) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index face906d682..41b65a0ac01 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -106,11 +106,7 @@ public RMSProp( float momentum, float epsilon, boolean centered) { - super(graph, learningRate); - this.decay = decay; - this.momentum = momentum; - this.epsilon = epsilon; - this.centered = centered; + this(graph, null, learningRate, decay, momentum, epsilon, centered); } /** From e9e2b24608f83de551608aa935caffff328a71d9 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 23 Sep 2020 10:53:52 -0400 Subject: [PATCH 12/14] Fixed close() routine to free up closed tensor in feedMap by setting feedMap to null. --- .../java/org/tensorflow/framework/optimizers/Optimizer.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 868d04672f1..5194cb32e73 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -276,7 +276,7 @@ protected Op finish(List updateOperations, String name) { /** * Sets the learning rate * - * @param newLearningRate the new earning rate + * @param newLearningRate the new learning rate */ public final void setLearningRate(float newLearningRate) { if (learningRatePlaceholder == null) { @@ -315,7 +315,7 @@ protected Operand getLearningRateOperand() { * Gets the Feed Map for the run methods to set the Placeholder value(s). 
Each entry in the Feed * Map contains a PlaceHolder and a Tensor with the value * - * @return the current Feed Map for the run methods, this may be null if an LearningRate as an + * @return the current Feed Map for the run methods, this may be null if the LearningRate is an * Operand has been set. */ public Map, Tensor> getFeedMap() { @@ -329,6 +329,7 @@ public void close() { learningRateTensor.close(); learningRateTensor = null; } + if (feedMap != null) feedMap = null; } /** Optional attributes for {@link org.tensorflow.framework.optimizers.Optimizer} */ From 62ff85c170983312b98e0d00561ae743db2b2e4c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 25 Sep 2020 11:02:05 -0400 Subject: [PATCH 13/14] Fix javadoc for references to Default values, Add Operand learningRateOperand as an option for learning rate. --- .../framework/optimizers/AdaDelta.java | 78 ++++++++- .../framework/optimizers/AdaGrad.java | 86 +++++++++- .../framework/optimizers/AdaGradDA.java | 121 ++++++++++++- .../tensorflow/framework/optimizers/Adam.java | 105 ++++++++++-- .../framework/optimizers/Adamax.java | 100 ++++++++++- .../tensorflow/framework/optimizers/Ftrl.java | 159 +++++++++++++++++- .../framework/optimizers/GradientDescent.java | 39 ++++- .../framework/optimizers/Momentum.java | 99 +++++++++-- .../framework/optimizers/Nadam.java | 111 +++++++++--- .../framework/optimizers/Optimizer.java | 40 ++++- .../framework/optimizers/RMSProp.java | 129 ++++++++++++-- 11 files changed, 964 insertions(+), 103 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 9f2e868ea1a..30abd0fcbe3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -62,6 +63,7 @@ */ public class AdaDelta extends Optimizer { + public static final String DEFAULT_NAME = "Adadelta"; public static final String ACCUMULATOR = "accum"; public static final String ACCUMULATOR_UPDATE = "accum_update"; public static final float LEARNING_RATE_DEFAULT = 0.001f; @@ -72,12 +74,20 @@ public class AdaDelta extends Optimizer { private final float epsilon; + /** + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learningRate, {@link #RHO_DEFAULT} for the rho, and {@link + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph. + */ public AdaDelta(Graph graph) { this(graph, LEARNING_RATE_DEFAULT, RHO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates an AdaDelta Optimizer + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #RHO_DEFAULT} for the rho, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -87,7 +97,19 @@ public AdaDelta(Graph graph, float learningRate) { } /** - * Creates an AdaDelta Optimizer + * Creates an AdaDelta Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #RHO_DEFAULT} for the rho, and {@link #EPSILON_DEFAULT} for the epsilon. 
+ * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaDelta(Graph graph, Operand learningRateOperand) { + this(graph, learningRateOperand, RHO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an AdaDelta Optimizer {@link #DEFAULT_NAME} for the Optimizer name * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -102,18 +124,45 @@ public AdaDelta(Graph graph, float learningRate, float rho, float epsilon) { * Creates an AdaDelta Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adadelta') + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param rho The decay factor + * @param epsilon A constant epsilon used to better conditioning the grad update + */ + public AdaDelta(Graph graph, Operand learningRateOperand, float rho, float epsilon) { + this(graph, null, learningRateOperand, rho, epsilon); + } + + /** + * Creates an AdaDelta Optimizer using {@link #RHO_DEFAULT} for the rho, and {@link * + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. * @param learningRate the learning rate */ public AdaDelta(Graph graph, String name, float learningRate) { - this(graph, name, learningRate, 0.95f, 1e-8f); + this(graph, name, learningRate, RHO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an AdaDelta Optimizer using {@link #RHO_DEFAULT} for the rho, and {@link * + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaDelta(Graph graph, String name, Operand learningRateOperand) { + this(graph, name, learningRateOperand, RHO_DEFAULT, EPSILON_DEFAULT); } /** * Creates an AdaDelta Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adadelta') + * @param name the name for this Optimizer. * @param learningRate the learning rate * @param rho The decay factor * @param epsilon A constant epsilon used to better conditioning the grad update @@ -124,6 +173,23 @@ public AdaDelta(Graph graph, String name, float learningRate, float rho, float e this.epsilon = epsilon; } + /** + * Creates an AdaDelta Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. 
+ * @param rho The decay factor + * @param epsilon A constant epsilon used to better conditioning the grad update + */ + public AdaDelta( + Graph graph, String name, Operand learningRateOperand, float rho, float epsilon) { + super(graph, name, learningRateOperand); + this.rho = rho; + this.epsilon = epsilon; + } + /** {@inheritDoc} */ @Override protected void createSlots(List> variables) { @@ -178,6 +244,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adadelta"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 9a9498630a8..c0cc47409d7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -40,6 +41,8 @@ */ public class AdaGrad extends Optimizer { + public static final String DEFAULT_NAME = "Adagrad"; + public static final String ACCUMULATOR = "accumulator"; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f; @@ -47,7 +50,9 @@ public class AdaGrad extends Optimizer { private final float initialAccumulatorValue; /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, and {@link * #INITIAL_ACCUMULATOR_DEFAULT} for + * the initialAccumulatorValue. * * @param graph the TensorFlow Graph */ @@ -56,7 +61,8 @@ public AdaGrad(Graph graph) { } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using using {@link #DEFAULT_NAME} for the Optimizer name, {@link * + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -66,7 +72,19 @@ public AdaGrad(Graph graph, float learningRate) { } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using using {@link #DEFAULT_NAME} for the Optimizer name, {@link * + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGrad(Graph graph, Operand learningRateOperand) { + this(graph, learningRateOperand, INITIAL_ACCUMULATOR_DEFAULT); + } + + /** + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -78,21 +96,49 @@ public AdaGrad(Graph graph, float learningRate, float initialAccumulatorValue) { } /** - * Creates an AdaGrad Optimizer + * Creates an AdaGrad Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. 
+ * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative + */ + public AdaGrad( + Graph graph, Operand learningRateOperand, float initialAccumulatorValue) { + this(graph, null, learningRateOperand, initialAccumulatorValue); + } + + /** + * Creates an AdaGrad Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue. * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adagrad') + * @param name the name for this Optimizer . * @param learningRate the learning rate */ public AdaGrad(Graph graph, String name, float learningRate) { - this(graph, name, learningRate, 0.01f); + this(graph, name, learningRate, INITIAL_ACCUMULATOR_DEFAULT); + } + + /** + * Creates an AdaGrad Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGrad(Graph graph, String name, Operand learningRateOperand) { + this(graph, name, learningRateOperand, INITIAL_ACCUMULATOR_DEFAULT); } /** * Creates an AdaGrad Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'Adagrad') + * @param name the name for this Optimizer * @param learningRate the learning rate * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative @@ -107,6 +153,30 @@ public AdaGrad(Graph graph, String name, float learningRate, float initialAccumu this.initialAccumulatorValue = initialAccumulatorValue; } + /** + * Creates an AdaGrad Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be non-negative. 
+ * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is negative + */ + public AdaGrad( + Graph graph, + String name, + Operand learningRateOperand, + float initialAccumulatorValue) { + super(graph, name, learningRateOperand); + if (initialAccumulatorValue < 0F) { + throw new IllegalArgumentException( + String.format( + "initialAccumulatorValue must be non-negative: %f", initialAccumulatorValue)); + } + this.initialAccumulatorValue = initialAccumulatorValue; + } + /** {@inheritDoc} */ @Override protected void createSlots(List> variables) { @@ -149,6 +219,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adagrad"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index af5c3737251..0e070f2f4fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -22,6 +22,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -46,6 +47,7 @@ */ public class AdaGradDA extends Optimizer { + public static final String DEFAULT_NAME = "adagrad-da"; public static final String ACCUMULATOR = "gradient_accumulator"; public static final String SQUARED_ACCUMULATOR = "gradient_squared_accumulator"; public static final float LEARNING_RATE_DEFAULT = 0.001F; @@ -59,7 +61,10 @@ public class AdaGradDA extends Optimizer { private Variable globalStep; /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. * * @param graph the TensorFlow Graph */ @@ -73,7 +78,9 @@ public AdaGradDA(Graph graph) { } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for + * the l1Strength, and {@link #L2_STRENGTH_DEFAULT} for the l2Strength. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -84,7 +91,25 @@ public AdaGradDA(Graph graph, float learningRate) { } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #INITIAL_ACCUMULATOR_DEFAULT} for the initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for + * the l1Strength, and {@link #L2_STRENGTH_DEFAULT} for the l2Strength. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGradDA(Graph graph, Operand learningRateOperand) { + this( + graph, + learningRateOperand, + INITIAL_ACCUMULATOR_DEFAULT, + L1_STRENGTH_DEFAULT, + L2_STRENGTH_DEFAULT); + } + + /** + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. 
* * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -104,10 +129,33 @@ public AdaGradDA( } /** - * Creates an AdaGradDA Optimizer + * Creates an AdaGradDA Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'adagrad-da') + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be greater than zero. + * @param l1Strength l1 regularization strength, must be greater than or equal to zero. + * @param l2Strength l2 regularization strength, must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is not greater than zero, + * or l1Strength or l2Strength is less than zero + */ + public AdaGradDA( + Graph graph, + Operand learningRateOperand, + float initialAccumulatorValue, + float l1Strength, + float l2Strength) { + this(graph, null, learningRateOperand, initialAccumulatorValue, l1Strength, l2Strength); + } + + /** + * Creates an AdaGradDA Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. * @param learningRate the learning rate */ public AdaGradDA(Graph graph, String name, float learningRate) { @@ -120,11 +168,31 @@ public AdaGradDA(Graph graph, String name, float learningRate) { L2_STRENGTH_DEFAULT); } + /** + * Creates an AdaGradDA Optimizer using {@link #INITIAL_ACCUMULATOR_DEFAULT} for the + * initialAccumulatorValue, {@link #L1_STRENGTH_DEFAULT} for the l1Strength, and {@link + * #L2_STRENGTH_DEFAULT} for the l2Strength. + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public AdaGradDA(Graph graph, String name, Operand learningRateOperand) { + this( + graph, + name, + learningRateOperand, + INITIAL_ACCUMULATOR_DEFAULT, + L1_STRENGTH_DEFAULT, + L2_STRENGTH_DEFAULT); + } + /** * Creates an AdaGradDA Optimizer * * @param graph the TensorFlow Graph - * @param name the name for this Optimizer (defaults to 'adagrad-da') + * @param name the name for this Optimizer. * @param learningRate the learning rate * @param initialAccumulatorValue Starting value for the accumulators, must be positive * @param l1Strength l1 regularization strength, must be greater than or equal to zero. @@ -158,6 +226,45 @@ public AdaGradDA( this.l2Strength = l2Strength; } + /** + * Creates an AdaGradDA Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param initialAccumulatorValue Starting value for the accumulators, must be positive + * @param l1Strength l1 regularization strength, must be greater than or equal to zero. + * @param l2Strength l2 regularization strength, must be greater than or equal to zero. 
+ * @throws java.lang.IllegalArgumentException if initialAccumulatorValue is not greater than zero, + * or l1Strength or l2Strength is less than zero + */ + public AdaGradDA( + Graph graph, + String name, + Operand learningRateOperand, + float initialAccumulatorValue, + float l1Strength, + float l2Strength) { + super(graph, name, learningRateOperand); + if (initialAccumulatorValue <= 0F) { + throw new IllegalArgumentException( + String.format( + "initialAccumulatorValue must be greater than zero: %f", initialAccumulatorValue)); + } + if (l1Strength < 0F) { + throw new IllegalArgumentException( + String.format("l1Strength must not be negative: %f", l1Strength)); + } + if (l2Strength < 0F) { + throw new IllegalArgumentException( + String.format("l2Strength must not be negative: %f", l2Strength)); + } + this.initialAccumulatorValue = initialAccumulatorValue; + this.l1Strength = l1Strength; + this.l2Strength = l2Strength; + } + /** {@inheritDoc} */ @Override protected Optional prepare(String name) { @@ -240,6 +347,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "adagrad-da"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index 3ca9fbdab57..9e4f41f1039 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -48,6 +48,8 @@ @Operator public class Adam extends Optimizer { + public static final String DEFAULT_NAME = "Adam"; + public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -69,7 +71,9 @@ public class Adam extends Optimizer { private Variable betaTwoPower; /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, + * {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -78,7 +82,9 @@ public Adam(Graph graph) { } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph * @param learningRate the learning rate */ public Adam(Graph graph, float learningRate) { this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adam(Graph graph, Operand learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name.
* * @param graph the TensorFlow graph * @param learningRate the learning rate - * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9. - * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm - * 1 of the paper. Defaults to 1e-8. + * 1 of the paper.. */ public Adam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** - * Creates an Adam optimizer + * Creates an Adam optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param name the Optimizer name, defaults to "Adam" + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. + * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the + * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm + * 1 of the paper. + */ + public Adam( + Graph graph, + Operand learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } + + /** + * Creates an Adam optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. * @param learningRate the learning rate */ public Adam(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Adam optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adam(Graph graph, String name, Operand learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + /** * Creates an Adam optimizer * * @param graph the TensorFlow graph - * @param name the Optimizer name, defaults to "Adam" + * @param name the Optimizer name. * @param learningRate the learning rate - * @param betaOne The exponential decay rate for the 1st moment estimates. Defaults to 0.9. - * @param betaTwo The exponential decay rate for the 2nd moment estimates. Defaults to 0.999. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm - * 1 of the paper. Defaults to 1e-8. + * 1 of the paper. 
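A minimal usage sketch for the Operand-based Adam constructors above (not part of the diff; the class and variable names are illustrative, and the scaled constant mirrors what the new test cases later in this patch feed in):

  import org.tensorflow.Graph;
  import org.tensorflow.Operand;
  import org.tensorflow.op.Ops;
  import org.tensorflow.types.TFloat32;
  import org.tensorflow.framework.optimizers.Adam;

  public class AdamWithLROperandExample {
    public static void main(String[] args) {
      try (Graph graph = new Graph()) {
        Ops tf = Ops.create(graph);
        // Any Operand<TFloat32> can drive the learning rate; here it is simply a constant
        // scaled in-graph, so no Placeholder or feed map is involved.
        Operand<TFloat32> lr = tf.math.mul(tf.constant(0.001f), tf.constant(1f));
        try (Adam adam = new Adam(graph, lr, 0.9f, 0.999f, 1e-8f)) {
          // Gradients are then applied exactly as with the float-based constructors, e.g.
          // Op update = adam.applyGradients(gradsAndVars, "AdamTraining");
        }
      }
    }
  }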
*/ public Adam( Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { @@ -133,6 +186,32 @@ public Adam( this.epsilon = epsilon; } + /** + * Creates an Adam optimizer + * + * @param graph the TensorFlow graph + * @param name the Optimizer name. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the 2nd moment estimates. + * @param epsilon A small constant for numerical stability. This epsilon is "epsilon hat" in the + * Kingma and Ba paper (in the formula just before Section 2.1), not the epsilon in Algorithm + * 1 of the paper. + */ + public Adam( + Graph graph, + String name, + Operand learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + /** * Creates the Operation that minimizes the loss * @@ -265,6 +344,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adam"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index c381013e97c..e33775db961 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -25,6 +25,7 @@ */ public class Adamax extends Optimizer { + public static final String DEFAULT_NAME = "Adamax"; public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; @@ -43,7 +44,10 @@ public class Adamax extends Optimizer { private Variable betaOnePower; /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} + * for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -52,17 +56,21 @@ public Adamax(Graph graph) { } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, {@link #LEARNING_RATE_DEFAULT} for + * the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} + * for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". + * @param name name for the operations Created when applying gradients. */ public Adamax(Graph graph, String name) { this(graph, name, LEARNING_RATE_DEFAULT, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for + * the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph * @param learningRate The learning rate. 
@@ -72,18 +80,48 @@ public Adamax(Graph graph, float learningRate) { } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, using {@link #DEFAULT_NAME} for the + * Optimizer name, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for + * the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adamax(Graph graph, Operand learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates an Optimizer that implements the Adamax algorithm, using {@link #BETA_ONE_DEFAULT} for + * the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". + * @param name name for the operations Created when applying gradients. * @param learningRate The learning rate. */ public Adamax(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates an Optimizer that implements the Adamax algorithm, using {@link #BETA_ONE_DEFAULT} for + * the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link + * #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name name for the operations Created when applying gradients. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Adamax(Graph graph, String name, Operand learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } /** - * Creates an Optimizer that implements the Adamax algorithm. + * Creates an Optimizer that implements the Adamax algorithm, {@link #LEARNING_RATE_DEFAULT} for + * the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} + * for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph * @param learningRate The learning rate. @@ -94,12 +132,32 @@ public Adamax(Graph graph, String name, float learningRate) { public Adamax(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { this(graph, null, learningRate, betaOne, betaTwo, epsilon); } + /** + * Creates an Optimizer that implements the Adamax algorithm, {@link #LEARNING_RATE_DEFAULT} for + * the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} + * for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Adamax( + Graph graph, + Operand learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } /** * Creates an Optimizer that implements the Adamax algorithm. 
* * @param graph the TensorFlow graph - * @param name name for the operations Created when applying gradients. Defaults to "Adamax". + * @param name name for the operations Created when applying gradients. * @param learningRate The learning rate. * @param betaOne The exponential decay rate for the 1st moment estimates. * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. @@ -113,6 +171,30 @@ public Adamax( this.epsilon = epsilon; } + /** + * Creates an Optimizer that implements the Adamax algorithm. + * + * @param graph the TensorFlow graph + * @param name name for the operations Created when applying gradients. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Adamax( + Graph graph, + String name, + Operand learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + /** {@inheritDoc} */ @Override protected Optional prepare(String scopeName) { @@ -177,6 +259,6 @@ protected Op finish(List updateOperations, String name) { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Adamax"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index edbe91c62e9..35eeb7dc225 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -6,6 +6,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyFtrl; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -20,6 +21,8 @@ */ public class Ftrl extends Optimizer { + public static final String DEFAULT_NAME = "Ftrl"; + public static final String ACCUMULATOR = "gradient_accumulator"; public static final String LINEAR_ACCUMULATOR = "linear_accumulator"; @@ -37,7 +40,12 @@ public class Ftrl extends Optimizer { private final float l2ShrinkageRegularizationStrength; /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #LEARNING_RATE_POWER_DEFAULT} for the + * learningRatePower. {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, + * {@link #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength + * and {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph */ @@ -53,7 +61,12 @@ public Ftrl(Graph graph) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_DEFAULT} for the learning rate, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. 
{@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param name the name of this Optimizer @@ -71,7 +84,12 @@ public Ftrl(Graph graph, String name) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. {@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -88,7 +106,34 @@ public Ftrl(Graph graph, float learningRate) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. {@link + * #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Ftrl(Graph graph, Operand learningRateOperand) { + this( + graph, + learningRateOperand, + LEARNING_RATE_POWER_DEFAULT, + INITIAL_ACCUMULATOR_VALUE_DEFAULT, + L1STRENGTH_DEFAULT, + L2STRENGTH_DEFAULT, + L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT); + } + + /** + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. + * {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. * * @param graph the TensorFlow Graph * @param name the name of this Optimizer @@ -107,7 +152,31 @@ public Ftrl(Graph graph, String name, float learningRate) { } /** - * Creates a Ftrl Optimizer + * Creates an Ftrl Optimizer using {@link #LEARNING_RATE_POWER_DEFAULT} for the learningRatePower. + * {@link #INITIAL_ACCUMULATOR_VALUE_DEFAULT} for the initialAccumulatorValue, {@link + * #L1STRENGTH_DEFAULT} for the l1Strength, {@link #L2STRENGTH_DEFAULT} for the l2Strength and + * {@link #L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT} for the + * l2ShrinkageRegularizationStrength. + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Ftrl(Graph graph, String name, Operand learningRateOperand) { + this( + graph, + name, + learningRateOperand, + LEARNING_RATE_POWER_DEFAULT, + INITIAL_ACCUMULATOR_VALUE_DEFAULT, + L1STRENGTH_DEFAULT, + L2STRENGTH_DEFAULT, + L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT); + } + + /** + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. 
* * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -143,6 +212,44 @@ public Ftrl( l2ShrinkageRegularizationStrength); } + /** + * Creates an Ftrl Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param learningRatePower Controls how the learning rate decreases during training. Use zero for + * a fixed learning rate. + * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive + * values are allowed. + * @param l1Strength the L1 Regularization strength, must be greater than or equal to zero. + * @param l2Strength the L2 Regularization strength, must be greater than or equal to zero. + * @param l2ShrinkageRegularizationStrength This differs from L2 above in that the L2 above is a + * stabilization penalty, whereas this L2 shrinkage is a magnitude penalty. must be greater + * than or equal to zero. + * @throws java.lang.IllegalArgumentException if the initialAccumulatorValue, + * l1RegularizationStrength, l2RegularizationStrength, or l2ShrinkageRegularizationStrength + * are less than 0.0, or learningRatePower is greater than 0.0. + */ + public Ftrl( + Graph graph, + Operand learningRateOperand, + float learningRatePower, + float initialAccumulatorValue, + float l1Strength, + float l2Strength, + float l2ShrinkageRegularizationStrength) { + this( + graph, + null, + learningRateOperand, + learningRatePower, + initialAccumulatorValue, + l1Strength, + l2Strength, + l2ShrinkageRegularizationStrength); + } + /** * Creates a Ftrl Optimizer * @@ -180,7 +287,45 @@ public Ftrl( validateParams(); } - /** Validates all the settings of the Frtl Optmizer */ + /** + * Creates a Ftrl Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param learningRatePower Controls how the learning rate decreases during training. Use zero for + * a fixed learning rate. + * @param initialAccumulatorValue The starting value for accumulators. Only zero or positive + * values are allowed. + * @param l1Strength the L1 Regularization strength, must be greater than or equal to zero. + * @param l2Strength the L2 Regularization strength, must be greater than or equal to zero. + * @param l2ShrinkageRegularizationStrength This differs from L2 above in that the L2 above is a + * stabilization penalty, whereas this L2 shrinkage is a magnitude penalty. must be greater + * than or equal to zero. + * @throws java.lang.IllegalArgumentException if the initialAccumulatorValue, + * l1RegularizationStrength, l2RegularizationStrength, or l2ShrinkageRegularizationStrength + * are less than 0.0, or learningRatePower is greater than 0.0. 
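The same pattern applies to Ftrl; a short sketch, reusing the assumed graph and tf from the Adam example above with illustrative values:

  // The single-Operand overload delegates learningRatePower, initialAccumulatorValue and all
  // regularization strengths to their *_DEFAULT values.
  Operand<TFloat32> ftrlRate = tf.constant(0.01f);
  try (Ftrl ftrl = new Ftrl(graph, ftrlRate)) {
    // Op update = ftrl.applyGradients(gradsAndVars, "FtrlTraining");
  }
  // validateParams() guards the Operand-based constructors just as it does the float-based
  // ones; for example a negative initialAccumulatorValue such as
  // new Ftrl(graph, ftrlRate, Ftrl.LEARNING_RATE_POWER_DEFAULT, -1.0f, 0f, 0f, 0f)
  // is rejected with an IllegalArgumentException.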
+ */ + public Ftrl( + Graph graph, + String name, + Operand learningRateOperand, + float learningRatePower, + float initialAccumulatorValue, + float l1Strength, + float l2Strength, + float l2ShrinkageRegularizationStrength) { + super(graph, name, learningRateOperand); + this.learningRatePower = learningRatePower; + this.initialAccumulatorValue = initialAccumulatorValue; + this.l1RegularizationStrength = l1Strength; + this.l2RegularizationStrength = l2Strength; + this.l2ShrinkageRegularizationStrength = l2ShrinkageRegularizationStrength; + validateParams(); + } + + /** Validates all the settings of the Ftrl Optimizer */ private void validateParams() { if (this.initialAccumulatorValue < 0.0F) { throw new IllegalArgumentException( @@ -257,6 +402,6 @@ protected Op applyDense(Output gradient, Output variable /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Ftrl"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index f57503d3347..efec399f40d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -16,8 +16,10 @@ package org.tensorflow.framework.optimizers; import org.tensorflow.Graph; +import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; /** @@ -26,10 +28,12 @@ */ public class GradientDescent extends Optimizer { + public static final String DEFAULT_NAME = "GradientDescent"; public static final float LEARNING_RATE_DEFAULT = 0.01f; /** - * Creates a GradientDescent Optimizer + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and + * {@link #LEARNING_RATE_DEFAULT} for the learning rate. * * @param graph the TensorFlow graph */ @@ -38,26 +42,49 @@ public GradientDescent(Graph graph) { } /** - * Creates a GradientDescent Optimizer + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.01 + * @param learningRate the learning rate. */ public GradientDescent(Graph graph, float learningRate) { super(graph, null, learningRate); } + /** + * Creates a GradientDescent Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public GradientDescent(Graph graph, Operand learningRateOperand) { + super(graph, null, learningRateOperand); + } + /** * Creates a GradientDescent Optimizer * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, default is "GradientDescent" - * @param learningRate the learning rate, defaults to 0.01 + * @param name the name for this Optimizer. + * @param learningRate the learning rate. */ public GradientDescent(Graph graph, String name, float learningRate) { super(graph, name, learningRate); } + /** + * Creates a GradientDescent Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. 
+ */ + public GradientDescent(Graph graph, String name, Operand learningRateOperand) { + super(graph, name, learningRateOperand); + } + /** {@inheritDoc} */ @Override protected Op applyDense(Output gradient, Output variable) { @@ -74,6 +101,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "GradientDescent"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index 19e3f275f1f..436b587c353 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -21,6 +21,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; import org.tensorflow.op.train.ApplyMomentum; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -33,6 +34,7 @@ */ public class Momentum extends Optimizer { + public static final String DEFAULT_NAME = "Momentum"; public static final float LEARNING_RATE_DEFAULT = 0.01F; public static final float MOMENTUM_DEFAULT = 0.0F; public static final boolean NESTEROV_DEFAULT = false; @@ -44,7 +46,9 @@ public class Momentum extends Optimizer { private final boolean useNesterov; /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #MOMENTUM_DEFAULT} for the momentum, and + * {@link #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph */ @@ -53,7 +57,8 @@ public Momentum(Graph graph) { } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #MOMENTUM_DEFAULT} for the momentum, and {@link #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph * @param learningRate the learning rate @@ -63,30 +68,76 @@ public Momentum(Graph graph, float learningRate) { } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #MOMENTUM_DEFAULT} for the momentum, and {@link #NESTEROV_DEFAULT} for the Nesterov flag. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Momentum(Graph graph, Operand learningRateOperand) { + this(graph, learningRateOperand, MOMENTUM_DEFAULT, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and {@link + * #NESTEROV_DEFAULT} for the Nesterov flag. * * @param graph the TensorFlow graph * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. + * dampens oscillations, Must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum(Graph graph, float learningRate, float momentum) { this(graph, learningRate, momentum, NESTEROV_DEFAULT); } /** - * Creates a Momentum Optimizer + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name and {@link + * #NESTEROV_DEFAULT} for the Nesterov flag. 
+ * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. + */ + public Momentum(Graph graph, Operand learningRateOperand, float momentum) { + this(graph, learningRateOperand, momentum, NESTEROV_DEFAULT); + } + + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum(Graph graph, float learningRate, float momentum, boolean useNesterov) { this(graph, null, learningRate, momentum, useNesterov); } + /** + * Creates a Momentum Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. + */ + public Momentum( + Graph graph, Operand learningRateOperand, float momentum, boolean useNesterov) { + this(graph, null, learningRateOperand, momentum, useNesterov); + } + /** * Creates a Momentum Optimizer * @@ -94,12 +145,40 @@ public Momentum(Graph graph, float learningRate, float momentum, boolean useNest * @param name the name for this Optimizer * @param learningRate the learning rate * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and - * dampens oscillations, Must be greater than or equal to zero. Default is 0. - * @param useNesterov Whether to apply Nesterov momentum. Defaults to false. + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. + * @throws java.lang.IllegalArgumentException if momentum is less than zero. */ public Momentum( Graph graph, String name, float learningRate, float momentum, boolean useNesterov) { super(graph, name, learningRate); + if (momentum < 0) + throw new IllegalArgumentException("momentum must be greater than or equal to zero."); + this.momentum = momentum; + this.useNesterov = useNesterov; + } + + /** + * Creates a Momentum Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param momentum hyperparameter that accelerates gradient descent in the relevant direction and + * dampens oscillations, Must be greater than or equal to zero. + * @param useNesterov Whether to apply Nesterov momentum. 
+ * @throws java.lang.IllegalArgumentException if momentum is less than zero. + */ + public Momentum( + Graph graph, + String name, + Operand learningRateOperand, + float momentum, + boolean useNesterov) { + super(graph, name, learningRateOperand); + if (momentum < 0) + throw new IllegalArgumentException("momentum must be greater than or equal to zero."); this.momentum = momentum; this.useNesterov = useNesterov; } @@ -152,6 +231,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Momentum"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index ece7c024969..202f5013e7d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -25,6 +25,7 @@ */ public class Nadam extends Optimizer { + public static final String DEFAULT_NAME = "Nadam"; private static final float DECAY_BASE = 0.96f; private static final float DECAY = 0.004f; public static final float LEARNING_RATE_DEFAULT = 0.001f; @@ -65,7 +66,9 @@ public class Nadam extends Optimizer { private Operand vTPrimeDenominator; /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #BETA_ONE_DEFAULT} for the betaOne value, + * {@link #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph */ @@ -74,50 +77,96 @@ public Nadam(Graph graph) { } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.001 + * @param learningRate the learning rate. */ public Nadam(Graph graph, float learningRate) { this(graph, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #BETA_ONE_DEFAULT} for the betaOne value, {@link #BETA_TWO_DEFAULT} for the betaTwo value, and + * {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Nadam(Graph graph, Operand learningRateOperand) { + this(graph, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + + /** + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow graph - * @param learningRate the learning rate, defaults to 0.001 - * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default - * is 0.999. - * @param epsilon A small constant for numerical stability. Default is 1e-8. + * @param learningRate the learning rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. 
+ * @param epsilon A small constant for numerical stability. */ public Nadam(Graph graph, float learningRate, float betaOne, float betaTwo, float epsilon) { this(graph, null, learningRate, betaOne, betaTwo, epsilon); } /** - * Creates a Nadam Optimizer + * Creates a Nadam Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. + */ + public Nadam( + Graph graph, + Operand learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + this(graph, null, learningRateOperand, betaOne, betaTwo, epsilon); + } + + /** + * Creates a Nadam Optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, defaults to "Nadam" - * @param learningRate the learning rate, defaults to 0.001 + * @param name the name for this Optimizer. + * @param learningRate the learning rate. */ public Nadam(Graph graph, String name, float learningRate) { this(graph, name, learningRate, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); } + /** + * Creates a Nadam Optimizer using {@link #BETA_ONE_DEFAULT} for the betaOne value, {@link + * #BETA_TWO_DEFAULT} for the betaTwo value, and {@link #EPSILON_DEFAULT} for the epsilon. + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public Nadam(Graph graph, String name, Operand learningRateOperand) { + this(graph, name, learningRateOperand, BETA_ONE_DEFAULT, BETA_TWO_DEFAULT, EPSILON_DEFAULT); + } + /** * Creates a Nadam Optimizer * * @param graph the TensorFlow graph - * @param name the name for this Optimizer, defaults to "Nadam" - * @param learningRate the learning rate, defaults to 0.001 - * @param betaOne The exponential decay rate for the 1st moment estimates. Default is 0.9. - * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. Default - * is 0.999. - * @param epsilon A small constant for numerical stability. Default is 1e-8. + * @param name the name for this Optimizer. + * @param learningRate the learning rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. */ public Nadam( Graph graph, String name, float learningRate, float betaOne, float betaTwo, float epsilon) { @@ -127,6 +176,30 @@ public Nadam( this.epsilon = epsilon; } + /** + * Creates a Nadam Optimizer + * + * @param graph the TensorFlow graph + * @param name the name for this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param betaOne The exponential decay rate for the 1st moment estimates. + * @param betaTwo The exponential decay rate for the exponentially weighted infinity norm. + * @param epsilon A small constant for numerical stability. 
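One behavioral detail of the Momentum overloads a few hunks above is worth spelling out: the Operand-based constructors validate momentum before any graph nodes are created. A tiny sketch with illustrative values, again assuming the graph and tf from the earlier examples:

  try (Momentum sgd = new Momentum(graph, tf.constant(0.01f), 0.9f, true)) {
    // Nesterov momentum, learning rate supplied as an in-graph Operand.
  }
  // new Momentum(graph, tf.constant(0.01f), -0.1f, false)
  // throws IllegalArgumentException("momentum must be greater than or equal to zero.")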
+ */ + public Nadam( + Graph graph, + String name, + Operand learningRateOperand, + float betaOne, + float betaTwo, + float epsilon) { + super(graph, name, learningRateOperand); + this.betaOne = betaOne; + this.betaTwo = betaTwo; + this.epsilon = epsilon; + } + /** {@inheritDoc} */ @Override protected void createSlots(List> variables) { @@ -287,6 +360,6 @@ protected Op finish(List updateOperations, String name) { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "Nadam"; + return DEFAULT_NAME; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index 5194cb32e73..586fef28c1e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -50,6 +50,7 @@ public abstract class Optimizer implements AutoCloseable { protected Placeholder learningRatePlaceholder = null; private Tensor learningRateTensor; private Map, Tensor> feedMap = null; + private Operand learningRateOperand; /** * Builds an optimizer for the supplied graph. @@ -66,6 +67,21 @@ protected Optimizer(Graph graph, String name, float learningRate) { setLearningRate(learningRate); } + /** + * Builds an optimizer for the supplied graph. + * + * @param graph The graph to optimize. + * @param name The base name for the operations. + * @param learningRateOperand the learning rate. + */ + protected Optimizer(Graph graph, String name, Operand learningRateOperand) { + this.graph = graph; + this.tf = Ops.create(graph).withName(name == null ? getOptimizerName() : name); + this.slots = new HashMap<>(); + this.globals = new ArrayList<>(); + setLearningRateOperand(learningRateOperand); + } + /** * Creates a name by combining a variable name and a slot name * @@ -293,6 +309,17 @@ public final void setLearningRate(float newLearningRate) { } } + /** + * Sets the learning rate Operand. The learning rate operand is an operand that is used to + * calculate the learning rate. + * + * @param newLearningRateOperand the new learning rate operand. + */ + public final void setLearningRateOperand(Operand newLearningRateOperand) { + close(); // Cleanup the placeholder and tensor if they exist. + learningRateOperand = newLearningRateOperand; + } + /** * Gets the learning rate * @@ -303,20 +330,23 @@ public float getLearningRate() { } /** - * Gets the learning rate Operand, used by subclasses in their graph operations + * Gets the learning rate Operand, used by subclasses in their graph operations. If a float + * learning rate has been set using {@link #setLearningRate}, then this will be the learning rate + * Placeholder, otherwise the learning rate operand is returned as passed to {@link + * #setLearningRateOperand}. * * @return the learning rate Operand */ protected Operand getLearningRateOperand() { - return learningRatePlaceholder; + return learningRatePlaceholder == null ? learningRateOperand : learningRatePlaceholder; } /** * Gets the Feed Map for the run methods to set the Placeholder value(s). Each entry in the Feed - * Map contains a PlaceHolder and a Tensor with the value + * Map contains a PlaceHolder and a Tensor with the value. * - * @return the current Feed Map for the run methods, this may be null if the LearningRate is an - * Operand has been set. 
+ * @return the current Feed Map for the run methods, this will be null if the LearningRateOperand + * has been set. */ public Map, Tensor> getFeedMap() { return feedMap; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index 41b65a0ac01..7a03bec849d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; import java.util.List; @@ -47,6 +48,7 @@ */ public class RMSProp extends Optimizer { + public static final String DEFAULT_NAME = "RMSProp"; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float DECAY_DEFAULT = 0.9f; public static final float MOMENTUM_DEFAULT = 0.0f; @@ -62,7 +64,10 @@ public class RMSProp extends Optimizer { private final boolean centered; /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #LEARNING_RATE_DEFAULT} for the learning rate, {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph */ @@ -77,7 +82,9 @@ public RMSProp(Graph graph) { } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #DECAY_DEFAULT} for the decay, {@link #MOMENTUM_DEFAULT} for the momentum, {@link + * #EPSILON_DEFAULT} for the epsilon value and {@link #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph * @param learningRate the learning rate @@ -87,17 +94,36 @@ public RMSProp(Graph graph, float learningRate) { } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name, {@link + * #DECAY_DEFAULT} for the decay, {@link #MOMENTUM_DEFAULT} for the momentum, {@link + * #EPSILON_DEFAULT} for the epsilon value and {@link #CENTERED_DEFAULT} for the centered flag. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public RMSProp(Graph graph, Operand learningRateOperand) { + this( + graph, + learningRateOperand, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); + } + + /** + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. * * @param graph the TensorFlow Graph * @param learningRate the learning rate - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum the acceleration factor, default is 0. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum the acceleration factor. * @param epsilon A small constant for numerical stability * @param centered If true, gradients are normalized by the estimated variance of the * gradient; if false>, by the uncentered second moment. Setting this to * true> may help with training, but is slightly more expensive in terms of computation - * and memory. Defaults to false. + * and memory. 
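Stepping back to the Optimizer base-class changes a few hunks above: a float learning rate is routed through a Placeholder plus a feed map, while an Operand learning rate needs no feeding at all. A small sketch of that contract, with the same assumed graph and tf as before:

  try (GradientDescent byValue = new GradientDescent(graph, 0.01f);
       GradientDescent byOperand = new GradientDescent(graph, tf.constant(0.01f))) {
    // byValue routes the rate through the learning-rate Placeholder; its getFeedMap() supplies
    // the Placeholder-to-Tensor mapping for the run methods, and setLearningRate(...) swaps the
    // fed value between steps.
    // byOperand computes the rate in-graph, so, per the Javadoc above, getFeedMap() is null.
    assert byOperand.getFeedMap() == null;
  }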
*/ public RMSProp( Graph graph, @@ -110,10 +136,36 @@ public RMSProp( } /** - * Creates an RMSPRrop Optimizer + * Creates an RMSPRrop Optimizer using {@link #DEFAULT_NAME} for the Optimizer name. + * + * @param graph the TensorFlow Graph + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum the acceleration factor. + * @param epsilon A small constant for numerical stability + * @param centered If true, gradients are normalized by the estimated variance of the + * gradient; if false>, by the uncentered second moment. Setting this to + * true> may help with training, but is slightly more expensive in terms of computation + * and memory. + */ + public RMSProp( + Graph graph, + Operand learningRateOperand, + float decay, + float momentum, + float epsilon, + boolean centered) { + this(graph, null, learningRateOperand, decay, momentum, epsilon, centered); + } + + /** + * Creates an RMSPRrop Optimizer using {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. * * @param graph the TensorFlow Graph - * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param name the name of this Optimizer. * @param learningRate the learning rate */ public RMSProp(Graph graph, String name, float learningRate) { @@ -127,19 +179,40 @@ public RMSProp(Graph graph, String name, float learningRate) { CENTERED_DEFAULT); } + /** + * Creates an RMSPRrop Optimizer using {@link #DECAY_DEFAULT} for the decay, {@link + * #MOMENTUM_DEFAULT} for the momentum, {@link #EPSILON_DEFAULT} for the epsilon value and {@link + * #CENTERED_DEFAULT} for the centered flag. + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + */ + public RMSProp(Graph graph, String name, Operand learningRateOperand) { + this( + graph, + name, + learningRateOperand, + DECAY_DEFAULT, + MOMENTUM_DEFAULT, + EPSILON_DEFAULT, + CENTERED_DEFAULT); + } + /** * Creates an RMSPRrop Optimizer * * @param graph the TensorFlow Graph - * @param name the name of this Optimizer. Defaults to "RMSProp". + * @param name the name of this Optimizer. * @param learningRate the learning rate - * @param decay Discounting factor for the history/coming gradient. Defaults to 0.9. - * @param momentum The acceleration factor, default is 0. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum The acceleration factor,. * @param epsilon A small constant for numerical stability * @param centered If true, gradients are normalized by the estimated variance of the * gradient; if false>, by the uncentered second moment. Setting this to * true> may help with training, but is slightly more expensive in terms of computation - * and memory. Defaults to false. + * and memory. */ public RMSProp( Graph graph, @@ -156,6 +229,36 @@ public RMSProp( this.centered = centered; } + /** + * Creates an RMSPRrop Optimizer + * + * @param graph the TensorFlow Graph + * @param name the name of this Optimizer. + * @param learningRateOperand the learning rate Operand, this is used to calculate the learning + * rate. + * @param decay Discounting factor for the history/coming gradient. + * @param momentum The acceleration factor. 
+ * @param epsilon A small constant for numerical stability + * @param centered If true, gradients are normalized by the estimated variance of the + * gradient; if false>, by the uncentered second moment. Setting this to + * true> may help with training, but is slightly more expensive in terms of computation + * and memory. + */ + public RMSProp( + Graph graph, + String name, + Operand learningRateOperand, + float decay, + float momentum, + float epsilon, + boolean centered) { + super(graph, name, learningRateOperand); + this.decay = decay; + this.momentum = momentum; + this.epsilon = epsilon; + this.centered = centered; + } + /** {@inheritDoc} */ @Override protected void createSlots(List> variables) { @@ -233,6 +336,6 @@ public String toString() { /** {@inheritDoc} */ @Override public String getOptimizerName() { - return "RMSProp"; + return DEFAULT_NAME; } } From ca1395e45fc8db61c29d5fcabbde3c1109f19239 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 25 Sep 2020 14:45:55 -0400 Subject: [PATCH 14/14] Added Operand learningRateOperand test case for learning rate. --- .../framework/optimizers/AdaDeltaTest.java | 107 +++++++++++- .../framework/optimizers/AdaGradDATest.java | 51 +++++- .../framework/optimizers/AdaGradTest.java | 64 +++++++ .../framework/optimizers/AdamTest.java | 138 ++++++++++++++- .../framework/optimizers/AdamaxTest.java | 131 +++++++++++++-- .../framework/optimizers/FtrlTest.java | 63 +++++++ .../optimizers/GradientDescentTest.java | 55 +++++- .../framework/optimizers/MomentumTest.java | 56 +++++- .../framework/optimizers/NadamTest.java | 135 +++++++++++++++ .../framework/optimizers/RMSPropTest.java | 159 +++++++++++++++++- 10 files changed, 936 insertions(+), 23 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java index 3547ea9a30e..7653c99bc98 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaDeltaTest.java @@ -167,6 +167,109 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + int numUpdates = 4; // # number of ADADELTA steps to perform + float[] grads = {0.2F, 0.1F, 0.01F}; + float[] lrs = {1.0F, 0.5F, 0.1F}; + + float rho = 0.95F; + float epsilon = 1e-8F; + + for (float grad : grads) { + for (float lr : lrs) { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // this just uses a trivial operand + try (AdaDelta instance = + new AdaDelta( + session.getGraph(), + tf.math.mul(tf.constant(lr), tf.constant(1.f)), + rho, + epsilon)) { + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] fgrads = {grad, grad}; + Shape shape = Shape.of(var0Init.length); + Variable var0 = tf.withName("var0").variable(shape, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant cgrads = tf.constant(fgrads); + float accum = 0.0F; + float accumUpdate = 0.0F; + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(cgrads.asOutput(), var1.asOutput())); + + /*apply gradients */ + Op adadeltaUpdate = 
instance.applyGradients(gradsAndVars, "AdaDeltaTest"); + + /* Create and validate the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable[] slots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable[] slotUpdates = new Variable[2]; + + slots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[0].asOutput().shape(), var0.asOutput().shape()); + + slotUpdates[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[0].asOutput().shape(), var0.asOutput().shape()); + + slots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(slots[1].asOutput().shape(), var1.asOutput().shape()); + + slotUpdates[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR_UPDATE).get(); + assertEquals(slotUpdates[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + float[] updates = new float[numUpdates]; + float totUpdate = 0; + for (int step = 0; step < numUpdates; step++) { + + session.run(adadeltaUpdate, instance.getFeedMap()); + accum = accum * rho + (float) Math.pow(grad, 2) * (1.0F - rho); + updates[step] = + ((float) Math.sqrt(accumUpdate + epsilon) + * (float) (1 / Math.sqrt(accum + epsilon)) + * grad); + accumUpdate = + (accumUpdate * rho + ((float) Math.pow(updates[step], 2) * (1.0F - rho))); + totUpdate += updates[step] * lr; + + for (int i = 0; i < 2; i++) { + session.evaluate(accum, slots[i]); + session.evaluate(accumUpdate, slotUpdates[i]); + } + + Float[] var0InitUpdate = {var0Init[0] - totUpdate, var0Init[1] - totUpdate}; + Float[] var1InitUpdate = {var1Init[0] - totUpdate, var1Init[1] - totUpdate}; + + session.evaluate(var0InitUpdate, var0); + session.evaluate(var1InitUpdate, var1); + } + } + } + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 4; // # number of ADADELTA steps to perform @@ -224,10 +327,10 @@ public void testWithLearningRateDecay() { session.run(var0Initializer); session.run(var1Initializer); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); - /** make sure the variables were initialized properly */ + /* make sure the variables were initialized properly */ session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java index 1f8044c1168..9e67b4660df 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradDATest.java @@ -97,6 +97,56 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {0.0F, 0.0F}; + float[] var1Init = {0.0F, 0.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.02F}; + float learningRate = 1.5F; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (AdaGradDA instance = + new AdaGradDA( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(2.f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = 
+        Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant grads0 = tf.constant(grads0Init);
+        Constant grads1 = tf.constant(grads1Init);
+
+        /* initialize the local variables */
+
+        session.run(var0Initializer);
+        session.run(var1Initializer);
+
+        /* build the GradsAndVars */
+        List> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradDATest");
+
+        /* initialize the accumulators */
+        session.run(tf.init());
+
+        session.evaluate(var0Init, var0);
+        session.evaluate(var1Init, var1);
+        session.run(adaUpdate, instance.getFeedMap());
+        float[] expected0 = {-0.904534F, -1.603567F};
+        session.evaluate(expected0, var0);
+        float[] expected1 = {-0.094821f, -0.189358f};
+        session.evaluate(expected1, var1);
+      }
+    }
+  }
+
   @Test
   public void testWithLearningRateDecay() {
     float[] var0Init = {0.0F, 0.0F};
@@ -104,7 +154,6 @@ public void testWithLearningRateDecay() {
     float[] grads0Init = {0.1F, 0.2F};
     float[] grads1Init = {0.01F, 0.02F};
     float epsilon = 1e-8F;
-    float epsilon1 = 1e-5F;
     int numSteps = 4;
     float learningRate = 3.0F;
     try (TestSession session = TestSession.createTestSession(tfMode);
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java
index 4a71fe59ba0..dc05b9b8f81 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdaGradTest.java
@@ -111,6 +111,70 @@ public void testBasic() {
     }
   }
 
+  @Test
+  public void testBasicWithLROperand() {
+    int numSteps = 3;
+    float[] var0Init = {1.0F, 2.0F};
+    float[] var1Init = {3.0F, 4.0F};
+    float[] grads0Init = {0.1F, 0.1F};
+    float[] grads1Init = {0.01F, 0.01F};
+
+    float learningRate = 1.0F;
+
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      try (AdaGrad instance =
+          new AdaGrad(
+              session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(3.f)), 0.1f)) {
+
+        Shape shape0 = Shape.of(var0Init.length);
+        Shape shape1 = Shape.of(var1Init.length);
+        Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE);
+        Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE);
+
+        Assign var0Initializer = tf.assign(var0, tf.constant(var0Init));
+        Assign var1Initializer = tf.assign(var1, tf.constant(var1Init));
+
+        Constant grads0 = tf.constant(grads0Init);
+        Constant grads1 = tf.constant(grads1Init);
+
+        /* build the GradsAndVars */
+        List> gradsAndVars = new ArrayList<>();
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput()));
+        gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput()));
+
+        Op adaUpdate = instance.applyGradients(gradsAndVars, "AdGradTest");
+
+        @SuppressWarnings("unchecked")
+        Variable[] accumulatorSlots = new Variable[2];
+        accumulatorSlots[0] = instance.getSlot(var0.asOutput(), ACCUMULATOR).get();
+        assertEquals(accumulatorSlots[0].asOutput().shape(), var0.asOutput().shape());
+
+
accumulatorSlots[1] = instance.getSlot(var1.asOutput(), ACCUMULATOR).get(); + assertEquals(accumulatorSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + for (int step = 0; step < numSteps; step++) { + session.run(adaUpdate, instance.getFeedMap()); + } + float[] expected0 = {-1.6026098728179932f, -0.6026098728179932f}; + session.evaluate(expected0, var0); + float[] expected1 = {2.715679168701172f, 3.715679168701172f}; + session.evaluate(expected1, var1); + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 3; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index a8be65c3650..c5bb153d804 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -65,8 +65,8 @@ public void testBasic() { FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); - float epsilon1 = 1e-3F; - float learningRate = 0.001F; + float epsilon1 = 1e-3f; + float learningRate = 0.001f; try (TestSession session = TestSession.createTestSession(tfMode); Adam instance = new Adam(session.getGraph(), learningRate)) { @@ -185,6 +185,140 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3f; + float learningRate = 0.001f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Adam instance = + new Adam( + session.getGraph(), tf.constant(learningRate))) { + + float beta1 = 0.9F; + float beta2 = 0.999F; + + session.setEpsilon(epsilon1); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validate the shapes of the slots */ + @SuppressWarnings("unchecked") + Variable[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable[] secondMomentSlots = new Variable[2]; + + 
firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the accumulators */ + session.run(tf.init(), instance.getFeedMap()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + FloatNdArray m0Np = NdArrays.ofFloats(shape0); + FloatNdArray v0Np = NdArrays.ofFloats(shape0); + FloatNdArray m1Np = NdArrays.ofFloats(shape1); + FloatNdArray v1Np = NdArrays.ofFloats(shape1); + + for (int step = 0; step < 3; step++) { + + // Test powers + final float[] powers = { + (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) + }; + + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); + } + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("beta2_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + } + session.run(update, instance.getFeedMap()); + + float lrT = + learningRate + * (float) Math.sqrt(1 - (float) Math.pow(beta2, (step + 1))) + / (1 - (float) Math.pow(beta1, (step + 1))); + + m0Np = calculateM(m0Np, grads0Np, beta1); + v0Np = calculateV(v0Np, grads0Np, beta2); + var0Np = calculateParam(var0Np, lrT, m0Np, v0Np, 1e-7F); + + m1Np = calculateM(m1Np, grads1Np, beta1); + v1Np = calculateV(v1Np, grads1Np, beta2); + var1Np = calculateParam(var1Np, lrT, m1Np, v1Np, 1e-7F); + + // evaluate var 0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + + // first moment + session.evaluate(m0Np, firstMomentSlots[0]); + session.evaluate(m1Np, firstMomentSlots[1]); + + // second moment + session.evaluate(v0Np, secondMomentSlots[0]); + session.evaluate(v1Np, secondMomentSlots[1]); + } + } + } + } + @Test public void testWithLearningRateDecay() { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index 57d3cbdb70c..e900018ccad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.optimizers; import org.junit.jupiter.api.*; -import org.tensorflow.Graph; import org.tensorflow.Tensor; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; @@ -61,10 +60,10 @@ public void tearDown() {} /** Test of getOptimizerName method, of class Adamax. 
*/ @Test public void testGetOptimizerName() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Graph graph = session.getGraph(); - Adamax instance = new Adamax(graph); - String expResult = "Adamax"; + try (TestSession session = TestSession.createTestSession(tfMode); + Adamax instance = new Adamax(session.getGraph())) { + + String expResult = DEFAULT_NAME; String result = instance.getOptimizerName(); assertEquals(expResult, result); } @@ -178,6 +177,118 @@ public void testBasic() { } } + /** Test of applyDense method, of class Adamax. */ + @Test + public void testBasicWithLROperand() { + + int numSteps = 3; + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3f; + float learningRate = 1f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Adamax instance = + new Adamax( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(1e-3f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + session.setEpsilon(epsilon1); + for (int step = 0; step < numSteps; step++) { + // Test powers + final float beta1Power = (float) 
Math.pow(BETA_ONE_DEFAULT, step + 1); + + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); + } + session.run(update, instance.getFeedMap()); + + FloatNdArray[] resultNP = calculate(var0Np, grads0Np, step, m0, v0); + var0Np = resultNP[VAR]; + m0 = resultNP[M]; + v0 = resultNP[V]; + + resultNP = calculate(var1Np, grads1Np, step, m1, v1); + var1Np = resultNP[VAR]; + m1 = resultNP[M]; + v1 = resultNP[V]; + + // evaluate var0 and var1 + + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + @Test public void testWithLearningRateDecay() { @@ -242,7 +353,7 @@ public void testWithLearningRateDecay() { secondMomentSlots[1] = instance.getSlot(var1.asOutput(), SECOND_MOMENT).get(); assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); - /** initialize the accumulators */ + /* initialize the accumulators */ session.run(tf.init()); /* initialize the local variables */ @@ -260,13 +371,7 @@ public void testWithLearningRateDecay() { .run() .get(0) .expect(TFloat32.DTYPE)) { - result - .data() - .scalars() - .forEach( - f -> { - assertEquals(betaPower, f.getFloat(), epsilon1); - }); + result.data().scalars().forEach(f -> assertEquals(betaPower, f.getFloat(), epsilon1)); } assertEquals(learningRate, instance.getLearningRate(), epsilon); session.evaluate( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java index 92b610e5951..c047cd4bf40 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/FtrlTest.java @@ -123,6 +123,69 @@ public void testFtrlWithL1L2L2Shrinkage() { } } + @Test + public void testFtrlWithL1L2L2ShrinkageWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {4.0F, 3.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.02F}; + float learningRate = 1.0F; + + int numSteps = 10; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Ftrl instance = + new Ftrl( + session.getGraph(), + tf.math.mul(tf.constant(learningRate), tf.constant(3f)), + -0.5F, // learningRatePower + 0.1F, // initialAccumulatorValue + 0.001F, // l1RegularizationStrength + 2.0F, // l2RegularizationStrength + 0.1F // l2ShrinkageRegularizationStrength + )) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op ftrlUpdate = instance.applyGradients(gradsAndVars, "FtrlTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ 
+ session.run(tf.init()); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + for (int i = 0; i < numSteps; i++) { + session.run(ftrlUpdate, instance.getFeedMap()); + } + + float[] expectedVar0 = {-0.22578995F, -0.44345796F}; + session.evaluate(expectedVar0, var0); + float[] expectedVar1 = {-0.14378493F, -0.13229476F}; + session.evaluate(expectedVar1, var1); + } + } + } + @Test public void testFtrlWithL1() { float[] var0Init = {1.0F, 2.0F}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 8e793e35d5f..ce687186994 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -56,7 +56,7 @@ public void testBasic() { float learningRate = 3.0F; try (TestSession session = TestSession.createTestSession(tfMode); - GradientDescent instance = new GradientDescent(session.getGraph(), learningRate); ) { + GradientDescent instance = new GradientDescent(session.getGraph(), learningRate)) { Ops tf = session.getTF(); Shape shape0 = Shape.of(var0Init.length); @@ -97,6 +97,59 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + float learningRate = 1.5f; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (GradientDescent instance = + new GradientDescent( + session.getGraph(), tf.math.mul(tf.constant(learningRate), tf.constant(2f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + session.run(update, instance.getFeedMap()); // 1 step + + float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; + float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; + session.evaluate(expectedVar0, var0); + session.evaluate(expectedVar1, var1); + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 2; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java index b54e3b52a26..014c72e55e6 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/MomentumTest.java @@ -71,9 +71,9 @@ public void testBasic() { float[] grads1Init = {0.01F, 0.01F}; float learningRate = 3.0F; - try (TestSession session = TestSession.createTestSession(tfMode)) { + try (TestSession session = TestSession.createTestSession(tfMode); + Momentum instance = new Momentum(session.getGraph(), learningRate)) { Ops tf = session.getTF(); - Graph graph = session.getGraph(); Shape shape0 = Shape.of(var0Init.length); Shape shape1 = Shape.of(var1Init.length); @@ -91,7 +91,6 @@ public void testBasic() { gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); - Momentum instance = new Momentum(graph, learningRate); Op update = instance.applyGradients(gradsAndVars, "SGDTest"); /* initialize the local variables */ @@ -114,6 +113,57 @@ public void testBasic() { } } + @Test + public void testBasicWithLROperand() { + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + float learningRate = 3.0F; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Momentum instance = new Momentum(session.getGraph(), tf.constant(learningRate))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "SGDTest"); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + session.run(update, instance.getFeedMap()); // 1 step + + float[] expectedVar0 = {1.0F - 3.0F * 0.1F, 2.0F - 3.0F * 0.1F}; + float[] expectedVar1 = {3.0F - 3.0F * 0.01F, 4.0F - 3.0F * 0.01F}; + session.evaluate(expectedVar0, var0); + session.evaluate(expectedVar1, var1); + } + } + } + @Test public void testMomentum() { float[] var0Init = {1.0F, 2.0F}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index c7c17689a33..0832543c104 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -203,6 +203,141 @@ public void testBasic() { } } + /** Test of applyDense method, of class Nadam. 
*/ + @Test + public void testBasicWithLROperand() { + + int numSteps = 3; + + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.1F}; + float[] grads1Init = {0.01F, 0.01F}; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; + FloatNdArray m0 = NdArrays.vectorOf(zeros); + FloatNdArray v0 = NdArrays.vectorOf(zeros); + FloatNdArray m1 = NdArrays.vectorOf(zeros); + FloatNdArray v1 = NdArrays.vectorOf(zeros); + FloatNdArray mcache = NdArrays.vectorOf(ones); + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + float epsilon1 = 1e-3F; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + try (Nadam instance = + new Nadam(session.getGraph(), tf.math.mul(tf.constant(1f), tf.constant(1e-3f)))) { + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new Optimizer.GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "AdamTest"); + + /* Create and validae the shapes of the slota */ + @SuppressWarnings("unchecked") + Variable[] firstMomentSlots = new Variable[2]; + @SuppressWarnings("unchecked") + Variable[] secondMomentSlots = new Variable[2]; + + firstMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + secondMomentSlots[0] = instance.getSlot(var0.asOutput(), Nadam.SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[0].asOutput().shape(), var0.asOutput().shape()); + + firstMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.FIRST_MOMENT).get(); + assertEquals(firstMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + secondMomentSlots[1] = instance.getSlot(var1.asOutput(), Nadam.SECOND_MOMENT).get(); + assertEquals(secondMomentSlots[1].asOutput().shape(), var1.asOutput().shape()); + + /* initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + session.setEpsilon(epsilon1); + + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + result.data().scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); + } + momentum = 1F; + + for (int step = 0; step < numSteps; step++) { + + session.run(update, instance.getFeedMap()); + + float mut = + Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); + momentum = momentum * mut; + + try (Tensor result = + session + .getGraphSession() + .runner() + .fetch("momentum") + .run() + .get(0) + .expect(TFloat32.DTYPE)) { + 
result.data().scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); + } + mcache = ND.mul(mcache, momentum); + FloatNdArray[] resultsNP = nadamUpdateNdArray(var0Np, grads0Np, step, m0, v0, mcache); + var0Np = resultsNP[VAR]; + m0 = resultsNP[M]; + v0 = resultsNP[V]; + + resultsNP = nadamUpdateNdArray(var1Np, grads1Np, step, m1, v1, mcache); + var1Np = resultsNP[VAR]; + m1 = resultsNP[M]; + v1 = resultsNP[V]; + + // evaluate m0 and m1 + session.evaluate(m0, firstMomentSlots[0]); + session.evaluate(m1, firstMomentSlots[1]); + + // evaluate v0 and v1 + session.evaluate(v0, secondMomentSlots[0]); + session.evaluate(v1, secondMomentSlots[1]); + + // evaluate var0 and var1 + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 3; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java index 2a012ff0f99..711358e9222 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/RMSPropTest.java @@ -73,7 +73,7 @@ public void testDense() { for (Object[] testParamValue : testParamValues) { // learningRate, rho (decay), momentum, epsilon, centered - float learningRate = (float) (float) testParamValue[0]; + float learningRate = (float) testParamValue[0]; float decay = (float) testParamValue[1]; float momentum = (float) testParamValue[2]; float epsilon = (float) testParamValue[3]; @@ -215,6 +215,163 @@ public void testDense() { } } + @Test + public void testDenseWithLROperand() { + + int numSteps = 3; + + for (Object[] testParamValue : testParamValues) { + // learningRate, rho (decay), momentum, epsilon, centered + float learningRate = (float) testParamValue[0]; + float decay = (float) testParamValue[1]; + float momentum = (float) testParamValue[2]; + float epsilon = (float) testParamValue[3]; + boolean centered = (boolean) testParamValue[4]; + try (TestSession session = TestSession.createTestSession(tfMode)) { + + Ops tf = session.getTF(); + try (RMSProp instance = + new RMSProp( + session.getGraph(), + tf.math.add(tf.constant(learningRate), tf.constant(0f)), + decay, + momentum, + epsilon, + centered)) { + + session.setEpsilon(1e-2f); + float[] var0Init = {1.0F, 2.0F}; + float[] var1Init = {3.0F, 4.0F}; + float[] grads0Init = {0.1F, 0.2F}; + float[] grads1Init = {0.01F, 0.2F}; + + FloatNdArray var0Np = NdArrays.vectorOf(var0Init); + FloatNdArray var1Np = NdArrays.vectorOf(var1Init); + FloatNdArray grads0Np = NdArrays.vectorOf(grads0Init); + FloatNdArray grads1Np = NdArrays.vectorOf(grads1Init); + + Shape shape0 = Shape.of(var0Init.length); + Shape shape1 = Shape.of(var1Init.length); + Variable var0 = tf.withName("var0").variable(shape0, TFloat32.DTYPE); + Variable var1 = tf.withName("var1").variable(shape1, TFloat32.DTYPE); + + Assign var0Initializer = tf.assign(var0, tf.constant(var0Init)); + Assign var1Initializer = tf.assign(var1, tf.constant(var1Init)); + + Constant grads0 = tf.constant(grads0Init); + Constant grads1 = tf.constant(grads1Init); + + /* build the GradsAnvVars */ + List> gradsAndVars = new ArrayList<>(); + gradsAndVars.add(new GradAndVar<>(grads0.asOutput(), var0.asOutput())); + gradsAndVars.add(new GradAndVar<>(grads1.asOutput(), var1.asOutput())); + + Op update = instance.applyGradients(gradsAndVars, "RMSPropTest"); + + /* 
initialize the local variables */ + session.run(var0Initializer); + session.run(var1Initializer); + + /* initialize the accumulators */ + session.run(tf.init()); + + /* make sure the variables were initialized properly */ + session.evaluate(var0Init, var0); + session.evaluate(var1Init, var1); + + Variable mg0 = + centered && instance.getSlot(var0.asOutput(), MG).isPresent() + ? instance.getSlot(var0.asOutput(), MG).get() + : null; + Variable mg1 = + centered && instance.getSlot(var1.asOutput(), MG).isPresent() + ? instance.getSlot(var1.asOutput(), MG).get() + : null; + Variable mom0 = + momentum > 0.F && instance.getSlot(var0.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var0.asOutput(), MOMENTUM).get() + : null; + Variable mom1 = + momentum > 0.F && instance.getSlot(var1.asOutput(), MOMENTUM).isPresent() + ? instance.getSlot(var1.asOutput(), MOMENTUM).get() + : null; + Variable rms0 = + instance.getSlot(var0.asOutput(), RMS).isPresent() + ? instance.getSlot(var0.asOutput(), RMS).get() + : null; + Variable rms1 = + instance.getSlot(var1.asOutput(), RMS).isPresent() + ? instance.getSlot(var1.asOutput(), RMS).get() + : null; + + float[] zeros = {0.0F, 0.0F}; + float[] ones = {1.0F, 1.0F}; // temp to match RMSProp + FloatNdArray mg0Np = NdArrays.vectorOf(zeros); + FloatNdArray mg1Np = NdArrays.vectorOf(zeros); + FloatNdArray rms0Np = NdArrays.vectorOf(ones); + FloatNdArray rms1Np = NdArrays.vectorOf(ones); + FloatNdArray mom0Np = NdArrays.vectorOf(zeros); + FloatNdArray mom1Np = NdArrays.vectorOf(zeros); + + for (int i = 0; i < numSteps; i++) { + session.run(update, instance.getFeedMap()); + FloatNdArray[] result0 = + calc( + var0Np, + grads0Np, + mg0Np, + rms0Np, + mom0Np, + learningRate, + decay, + momentum, + epsilon, + centered); + var0Np = result0[VAR_T]; + mg0Np = result0[MG_T]; + rms0Np = result0[RMS_T]; + mom0Np = result0[MOM_T]; + + FloatNdArray[] result1 = + calc( + var1Np, + grads1Np, + mg1Np, + rms1Np, + mom1Np, + learningRate, + decay, + momentum, + epsilon, + centered); + + var1Np = result1[VAR_T]; + mg1Np = result1[MG_T]; + rms1Np = result1[RMS_T]; + mom1Np = result1[MOM_T]; + + if (centered) { + if (mg0 != null) session.evaluate(mg0Np, mg0); + if (mg1 != null) session.evaluate(mg1Np, mg1); + } + + if (mom0 != null) session.evaluate(mom0Np, mom0); + if (mom1 != null) session.evaluate(mom1Np, mom1); + + /* TODO the values returned from rms slot, do not match what I see in the python test */ + if (rms0 != null) session.evaluate(rms0Np, rms0); + else fail("rms0 is null"); + if (rms1 != null) session.evaluate(rms1Np, rms1); + else fail("rms1 is null"); + + session.evaluate(var0Np, var0); + session.evaluate(var1Np, var1); + } + } + } + } + } + @Test public void testWithLearningRateDecay() { int numSteps = 3;