From b29edfdba6691276e3de3b0dee0d209d3eb215f3 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 11:16:37 -0500 Subject: [PATCH 01/97] Simplify generic parameters across losses and metrics. --- .../framework/losses/BinaryCrossentropy.java | 5 +- .../losses/CategoricalCrossentropy.java | 5 +- .../framework/losses/CategoricalHinge.java | 4 +- .../framework/losses/CosineSimilarity.java | 67 +++++++++++++--- .../tensorflow/framework/losses/Hinge.java | 4 +- .../tensorflow/framework/losses/Huber.java | 4 +- .../framework/losses/KLDivergence.java | 4 +- .../tensorflow/framework/losses/LogCosh.java | 4 +- .../org/tensorflow/framework/losses/Loss.java | 8 +- .../tensorflow/framework/losses/Losses.java | 73 ++++++++--------- .../framework/losses/MeanAbsoluteError.java | 4 +- .../losses/MeanAbsolutePercentageError.java | 4 +- .../framework/losses/MeanSquaredError.java | 4 +- .../losses/MeanSquaredLogarithmicError.java | 4 +- .../tensorflow/framework/losses/Poisson.java | 4 +- .../losses/SparseCategoricalCrossentropy.java | 5 +- .../framework/losses/SquaredHinge.java | 5 +- .../framework/metrics/BinaryCrossentropy.java | 9 ++- .../metrics/CategoricalCrossentropy.java | 13 +-- .../framework/metrics/CategoricalHinge.java | 5 +- .../framework/metrics/CosineSimilarity.java | 8 +- .../tensorflow/framework/metrics/Hinge.java | 5 +- .../framework/metrics/KLDivergence.java | 5 +- .../framework/metrics/LogCoshError.java | 5 +- .../tensorflow/framework/metrics/Mean.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 5 +- .../metrics/MeanAbsolutePercentageError.java | 7 +- .../framework/metrics/MeanSquaredError.java | 5 +- .../metrics/MeanSquaredLogarithmicError.java | 7 +- .../tensorflow/framework/metrics/Metric.java | 32 ++++---- .../tensorflow/framework/metrics/Metrics.java | 80 +------------------ .../tensorflow/framework/metrics/Poisson.java | 6 +- .../SparseCategoricalCrossentropy.java | 9 +-- .../framework/metrics/SquaredHinge.java | 5 +- 
.../framework/metrics/impl/LossMetric.java | 3 +- .../metrics/impl/MeanMetricWrapper.java | 21 +++-- .../framework/metrics/impl/MetricsHelper.java | 66 +++++++-------- .../framework/metrics/impl/Reduce.java | 68 ++++++++-------- .../metrics/BinaryCrossentropyTest.java | 10 +-- .../metrics/CategoricalCrossentropyTest.java | 10 +-- .../metrics/CategoricalHingeTest.java | 4 +- .../metrics/CosineSimilarityTest.java | 6 +- .../framework/metrics/HingeTest.java | 4 +- .../framework/metrics/KLDivergenceTest.java | 4 +- .../framework/metrics/LogCoshErrorTest.java | 4 +- .../metrics/MeanAbsoluteErrorTest.java | 4 +- .../MeanAbsolutePercentageErrorTest.java | 4 +- .../metrics/MeanSquaredErrorTest.java | 4 +- .../MeanSquaredLogarithmicErrorTest.java | 4 +- .../framework/metrics/PoissonTest.java | 4 +- .../SparseCategoricalCrossentropyTest.java | 6 +- .../framework/metrics/SquaredHingeTest.java | 4 +- 52 files changed, 293 insertions(+), 354 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java index c7edfcca24e..3417c07372a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/BinaryCrossentropy.java @@ -202,13 +202,12 @@ public BinaryCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. 
*/ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 363291fa5cc..035af9589ae 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -242,13 +242,12 @@ public CategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. 
*/ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index f592c19f8bb..4e9133d8835 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -99,8 +99,8 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java index 137c7025c04..0a18d93caf3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CosineSimilarity.java @@ -22,12 +22,13 @@ /** * Computes the cosine similarity between labels and predictions. * - *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 - * indicates orthogonality and values closer to -1indicate greater similarity. The values closer to - * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you - * try to maximize the proximity between predictions and targets. If either labels or predictions is - * a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and - * targets. + *

Note that it is a number between -1 and 1. When it is a negative + number between -1 and 0, 0 indicates orthogonality and + values closer to -1 indicate greater similarity. The values closer to 1 + indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + try to maximize the proximity between predictions and targets. If either labels or + predictions is a zero vector, cosine similarity will be 0 regardless of + the proximity between predictions and targets. * *

loss = -sum(l2Norm(labels) * l2Norm(predictions)) * @@ -71,7 +72,7 @@ public class CosineSimilarity extends Loss { public static final int DEFAULT_AXIS = -1; public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO; - private final int axis; + private final int[] axis; /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, an axis @@ -107,6 +108,17 @@ public CosineSimilarity(Ops tf, int axis) { this(tf, null, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name, and a + * Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, int[] axis) { + + this(tf, null, axis, DEFAULT_REDUCTION); + } /** * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} @@ -120,6 +132,18 @@ public CosineSimilarity(Ops tf, String name, int axis) { this(tf, name, axis, DEFAULT_REDUCTION); } + /** + * Creates a Cosine Similarity Loss using a Loss Reduction of {@link #DEFAULT_REDUCTION} + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + */ + public CosineSimilarity(Ops tf, String name, int[] axis) { + + this(tf, name, axis, DEFAULT_REDUCTION); + } + /** * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name and an * axis of {@link #DEFAULT_AXIS} @@ -153,6 +177,18 @@ public CosineSimilarity(Ops tf, String name, Reduction reduction) { */ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { + this(tf, null, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss using {@link Class#getSimpleName()} as the loss name + * + * @param tf the TensorFlow Ops + * @param axis The dimension along which the cosine similarity is computed. 
+ * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, int[] axis, Reduction reduction) { + this(tf, null, axis, reduction); } @@ -165,15 +201,28 @@ public CosineSimilarity(Ops tf, int axis, Reduction reduction) { * @param reduction Type of Reduction to apply to the loss. */ public CosineSimilarity(Ops tf, String name, int axis, Reduction reduction) { + this(tf, name, new int[] {axis}, reduction); + } + + /** + * Creates a Cosine Similarity Loss + * + * @param tf the TensorFlow Ops + * @param name the name of the loss + * @param axis The dimension along which the cosine similarity is computed. + * @param reduction Type of Reduction to apply to the loss. + */ + public CosineSimilarity(Ops tf, String name, int[] axis, Reduction reduction) { super(tf, name, reduction); this.axis = axis; } /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.cosineSimilarity(getTF(), labels, predictions, axis); + losses = tf.math.neg(losses); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 88b4a7aa056..37e7e367b9b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -121,8 +121,8 @@ public Hinge(Ops tf, String name, Reduction reduction) { * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. 
*/ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") Operand tLabels = predictions.type() == labels.type() ? (Operand)labels : cast(tf, labels, predictions.type()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 6d3e3f0c2ac..e8de632eb09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -130,8 +130,8 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 8cf3db8d518..b3c0206b409 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -99,8 +99,8 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, 
getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 1669669a768..812260d9881 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -105,8 +105,8 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index ae33d5dfa37..0f9b183f38c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -62,10 +62,9 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param labels the truth values or labels * @param predictions the predictions * @param The data type of the predictions and loss. - * @param The data type of the labels. * @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return call(labels, predictions, null); } @@ -82,11 +81,10 @@ public Operand call(Operand labels, * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. 
* @return the loss */ - public abstract Operand call( - Operand labels, Operand predictions, Operand sampleWeights); + public abstract Operand call( + Operand labels, Operand predictions, Operand sampleWeights); /** * Gets the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 0d25bd5e7e2..a5ced3d1df8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -48,11 +48,10 @@ public class Losses { * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean absolute error */ - public static Operand meanAbsoluteError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsoluteError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -70,11 +69,10 @@ public static Operand meanAbsoluteErro * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared error */ - public static Operand meanSquaredError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredError( + Ops tf, Operand labels, Operand predictions) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -91,11 +89,10 @@ public static Operand meanSquaredError * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and 
result - * @param the data type of the labels * @return the mean absolute percentage error */ - public static Operand meanAbsolutePercentageError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanAbsolutePercentageError( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -118,11 +115,10 @@ public static Operand meanAbsolutePerc * @param labels the labels * @param predictions the predictions * @param the data type of the predictions and result - * @param the data type of the labels * @return the mean squared logarithmic percentage error */ - public static Operand meanSquaredLogarithmicError( - Ops tf, Operand labels, Operand predictions) { + public static Operand meanSquaredLogarithmicError( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -152,8 +148,8 @@ public static Operand meanSquaredLogar * @param the data type of the predictions and labels * @return the binary crossentropy loss. */ - public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + public static Operand binaryCrossentropy( + Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -181,7 +177,7 @@ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); - /* TODO - skip this loggic for now. 
It requires walking back the inputs which is not yet possible + /* TODO - skip this logic for now. It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { // TODO - this does not work // TODO output = backtrackIdentity(output); @@ -225,9 +221,9 @@ private static Operand binaryCrossentropyHelper( * @param the data type of the predictions and labels * @return the categorical crossentropy loss. */ - public static Operand categoricalCrossentropy( + public static Operand categoricalCrossentropy( Ops tf, - Operand labels, + Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing, @@ -283,8 +279,8 @@ public static Operand categoricalCross * @param the data type of the predictions and labels * @return the categorical hinge loss */ - public static Operand categoricalHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand categoricalHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -329,8 +325,8 @@ public static Operand categoricalHinge * @param the data type of the predictions and labels * @return the cosine similarity loss */ - public static Operand cosineSimilarity( - Ops tf, Operand labels, Operand predictions, int axis) { + public static Operand cosineSimilarity( + Ops tf, Operand labels, Operand predictions, int[] axis) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); @@ -339,8 +335,7 @@ public static Operand cosineSimilarity tLabels = l2Normalize(tf, tLabels, axis); predictions = l2Normalize(tf, predictions, axis); Operand mathMul = tf.math.mul(tLabels, predictions); - Operand sum = tf.reduceSum(mathMul, 
tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); - return tf.math.neg(sum); + return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } /** @@ -355,8 +350,8 @@ public static Operand cosineSimilarity * @param the data type of the predictions and labels * @return the hinge loss */ - public static Operand hinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand hinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -391,8 +386,8 @@ public static Operand hinge( * @param the data type of the predictions and labels * @return the Huber loss */ - public static Operand huber( - Ops tf, Operand labels, Operand predictions, float delta) { + public static Operand huber( + Ops tf, Operand labels, Operand predictions, float delta) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -422,8 +417,8 @@ public static Operand huber( * @see Kullback?Leibler * divergence */ - public static Operand kullbackLeiblerDivergence( - Ops tf, Operand labels, Operand predictions) { + public static Operand kullbackLeiblerDivergence( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -452,8 +447,8 @@ public static Operand kullbackLeiblerD * @param the data type of the predictions and labels * @return the hyperbolic cosine divergence loss */ - public static Operand logCosh( - Ops tf, Operand labels, Operand predictions) { + public static Operand logCosh( + Ops tf, Operand labels, Operand predictions) { Class predictionType = 
predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -480,8 +475,8 @@ public static Operand logCosh( * @param the data type of the predictions and labels * @return the Poisson loss */ - public static Operand poisson( - Ops tf, Operand labels, Operand predictions) { + public static Operand poisson( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -507,8 +502,8 @@ public static Operand poisson( * @param the data type of the predictions and labels * @return the sparse categorical crossentropy loss */ - public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + public static Operand sparseCategoricalCrossentropy( + Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -553,7 +548,7 @@ public static Operand sparseCategorica int labelsRank = labelsShape.numDimensions(); boolean updateShape = labelsRank != predictionsRank - 1; - if (updateShape) { // TODO check to see if this is right + if (updateShape) { Shape newShape = labelsShape.take(labelsRank - 1); iLabels = tf.reshape(iLabels, tf.constant(newShape)); // flatten one dimension predictions = @@ -584,8 +579,8 @@ public static Operand sparseCategorica * @param the data type of the predictions and labels * @return the squared hinge loss */ - public static Operand squaredHinge( - Ops tf, Operand labels, Operand predictions) { + public static Operand squaredHinge( + Ops tf, Operand labels, Operand predictions) { Class predictionType = predictions.type(); 
Operand tLabels = cast(tf, labels, predictionType); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -651,7 +646,7 @@ private static Operand smoothCategoricalLabels( * @param axis Dimension along which to normalize. * @return the normalized values based on L2 norm */ - public static Operand l2Normalize(Ops tf, Operand x, int axis) { + public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index a2d5d5f8efc..594de1e1448 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -95,8 +95,8 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 49133df610b..275a2e136a0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -95,8 +95,8 @@ public MeanAbsolutePercentageError(Ops tf, String name, 
Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 2a6c2be885e..31df3e70e0b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -95,8 +95,8 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 2604e226b81..bef990d22bc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -95,8 +95,8 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, 
Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index c43be4f2821..9cf38aa0380 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -104,8 +104,8 @@ public Poisson(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.poisson(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index ea765e6f8fd..3ec33113e89 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -190,13 +190,12 @@ public SparseCategoricalCrossentropy( * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. 
*/ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { Operand lPredictions; if (!fromLogits) { // add predictions range check for 0 - 1 diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 4ad4c1c726c..968624db202 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -117,13 +117,12 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override - public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") Operand tLabels = predictions.type() == labels.type() ? 
(Operand)labels : cast(tf, labels, predictions.type()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 651a6fac0b0..abd2dcbbf40 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,17 +21,18 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * *

This is the crossentropy metric class to be used when there are only two label classes (0 and * 1). * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -60,7 +61,7 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index c330ea88eaa..be43f34b92e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -30,11 +30,10 @@ * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] * . * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class CategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class CategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -48,7 +47,8 @@ public class CategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. 
+ * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to + * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -68,7 +68,8 @@ public CategoricalCrossentropy( * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 * means that we will use a value of 0.1 for label 0 and 0.9 @@ -98,7 +99,7 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalCrossentropy( getTF(), labels, predictions, fromLogits, labelSmoothing, axis); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 2741a36edb6..c70f2d8643b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -24,10 +24,9 @@ /** * A Metric that computes the categorical hinge loss metric between labels and predictions. * - * @param the data type for the predictions. 
* @param The data type for the metric result */ -public class CategoricalHinge extends MeanMetricWrapper +public class CategoricalHinge< T extends TNumber> extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 458de092bec..5abbd095420 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.op.Ops; @@ -23,10 +24,9 @@ /** * A metric that computes the cosine similarity metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class CosineSimilarity extends MeanMetricWrapper +public class CosineSimilarity< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; @@ -76,8 +76,8 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity - return Metrics.cosineProximity(getTF(), labels, predictions, axis); + return Losses.cosineSimilarity(getTF(), labels, predictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index baf9ad8ab7d..e0aced6fa3e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -24,10 +24,9 @@ /** * A metric that computes the hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class Hinge extends MeanMetricWrapper +public class Hinge extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.hinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index efcbbcbb7f0..fa09f2784b5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -25,10 +25,9 @@ * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class KLDivergence extends MeanMetricWrapper +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** @@ -47,7 +46,7 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 3df8505d54b..c43551a6948 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -25,10 +25,9 @@ * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class LogCoshError extends MeanMetricWrapper +public class LogCoshError extends MeanMetricWrapper< T> implements LossMetric { /** @@ -47,7 +46,7 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.logCosh(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java index de1f5a5629e..8902b329bcc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -21,10 +21,9 @@ /** * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } * - * @param The data type for the metric values * @param The data type for the metric result */ -public class Mean extends Reduce { +public class Mean extends Reduce { /** * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index e27676932ff..d343ec77ab0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -24,10 +24,9 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanAbsoluteError extends MeanMetricWrapper +public class MeanAbsoluteError extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +45,7 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsoluteError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 84fa9b627b2..dd7d151260b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -24,11 +24,10 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +45,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c7edd6ebe93..c2bef576b30 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -24,10 +24,9 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanSquaredError extends MeanMetricWrapper +public class MeanSquaredError< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +45,7 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 199b6e0e114..c1cf4ca6c9a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -24,11 +24,10 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -46,7 +45,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index bbb2aa73da2..8ab21c58218 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -25,10 +25,9 @@ /** * Base class for Metrics * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Metric { +public abstract class Metric { /** The TensorFlow Ops */ private final Ops tf; @@ -75,10 +74,10 @@ protected Metric(Ops tf, String name, long seed) { * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. 
* @return a List of Operations to update the metric state - * @param the data type for sampleWeights */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -90,13 +89,13 @@ public List updateStateList(Operand values, Operand the data type for the labels - * @param the data type for the sampleWeights * @return a List of Operations to update the metric state */ @SuppressWarnings({"unchecked", "unused"}) - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { return Collections.EMPTY_LIST; } @@ -105,10 +104,10 @@ public List updateStateList( * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. - * @param the data type for sampleWeights * @return the Operation to update the metric state */ - public final Op updateState(Operand values, Operand sampleWeights) { + public final Op updateState( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -119,12 +118,12 @@ public final Op updateState(Operand values, Operand sa * @param labels the labels * @param predictions the predictions * @param sampleWeights sample weights to be applied to values, may be null. 
- * @param the data type for the labels - * @param the data type for the sampleWeights * @return the Operation to update the metric state */ - public final Op updateState( - Operand labels, Operand predictions, Operand sampleWeights) { + public final Op updateState( + Operand labels, + Operand predictions, + Operand sampleWeights) { List controlOps = updateStateList(labels, predictions, sampleWeights); return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); } @@ -149,10 +148,9 @@ public final Op updateState( * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, may be null. * @return the result, possibly with control dependencies - * @param the data type for the sampleWeights. */ - public final Operand callOnce( - Operand values, Operand sampleWeights) { + public final Operand callOnce( + Operand values, Operand sampleWeights) { List controlOps = updateStateList(values, sampleWeights); Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); return ltf.identity(result()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 0169bc6b8bc..95b74bf1eea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -17,7 +17,6 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TNumber; @@ -46,89 +45,14 @@ public class Metrics { * @param predictions The prediction values. * @param k Number of top elements to look at for computing accuracy. 
* @param the data type for the predictions and results - * @param the data type ofr the labels. * @return the Operand for the Top K categorical accuracy value. */ - public static Operand topKCategoricalAccuracy( - Ops tf, Operand labels, Operand predictions, long k) { + public static Operand topKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, long k) { Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); return CastHelper.cast( tf, tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); } - - /** - * Computes the cosine similarity between labels and predictions. - * - * @param tf the TensorFlow Ops - * @param labels The ground truth values. - * @param predictions The prediction values. - * @param axes The dimensions along which the cosine similarity is computed. - * @param the data type for the labels - * @param the data type for the predictions and result - * @return Cosine similarity value. - */ - public static Operand cosineProximity( - Ops tf, Operand labels, Operand predictions, int[] axes) { - Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); - labelsNorm = l2Normalize(tf, labelsNorm, axes); - - Operand predictionsNorm = l2Normalize(tf, predictions, axes); - Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); - return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); - } - - /** - * Normalizes along dimension axis using an L2 norm with an epsilon of {@link - * #L2_NORM_EPSILON}. - * - *

For a 1-D tensor with axis = 0, computes - * - *

-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   * 
- * - *

For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param The data type for x. - * @return the normalized values of x. - */ - public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { - return l2Normalize(tf, x, axes, L2_NORM_EPSILON); - } - - /** - * Normalizes along dimension axis using an L2 norm. - * - *

For a 1-D tensor with axis = 0, computes - * - *

-   *       output = x / sqrt(max(sum(x**2), epsilon))
-   * 
- * - *

For x with more dimensions, independently normalizes each 1-D slice along - * dimension axis. - * - * @param tf The TensorFlow ops - * @param x The operand to normalize - * @param axes Dimension(s) along which to normalize. - * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the - * divisor if norm < sqrt(epsilon). - * @param The data type for the values. - * @return the normalized values of x. - */ - public static Operand l2Normalize( - Ops tf, Operand x, int[] axes, float epsilon) { - Operand squareSum = - tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); - Operand y = - tf.math.rsqrt( - tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); - return tf.math.mul(x, y); - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 75a2031fbb5..af50b103a60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -24,10 +24,10 @@ /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param the data type for the predictions. + * @param The data type for the metric result. 
*/ -public class Poisson extends MeanMetricWrapper +public class Poisson< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { /** @@ -46,7 +46,7 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.poisson(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 2e01f722de6..a0c016b70b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -24,12 +24,11 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. - * - * @param the data type for the predictions. + *\ * @param The data type for the metric result. 
*/ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final int axis; @@ -55,7 +54,7 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 430dbbcc229..bd331a85eda 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -24,10 +24,9 @@ /** * A metric that computes the squared hinge loss metric between labels and predictions. * - * @param the data type for the predictions. * @param The data type for the metric result. 
*/ -public class SquaredHinge extends MeanMetricWrapper +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** @@ -46,7 +45,7 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.squaredHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index b7b87d313aa..70bb8133698 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -29,8 +29,7 @@ public interface LossMetric { * * @param labels the truth values or labels * @param predictions the predictions - * @param The data type of the labels. 
* @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 17c209a8fed..9a532a0294f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -17,13 +17,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.metrics.Mean; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. @@ -32,10 +33,9 @@ * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * - * @param the data type for the predictions. * @param The data type for the metric result */ -public class MeanMetricWrapper extends Mean { +public class MeanMetricWrapper extends Mean { /** The loss function interface */ protected LossMetric loss; @@ -85,22 +85,21 @@ protected void setLoss(LossMetric loss) { * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) - * @param the datatype of the labels - * @param the data type for sampleWeights * @return a List of control operations that updates the Mean state variables. 
*/ - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { if (labels == null || predictions == null) { throw new IllegalArgumentException("missing required inputs for labels and predictions"); } - Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); - Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); Operand losses = loss.call(tLabels, tPredictions); - return super.updateStateList( - CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + return super.updateStateList(cast(getTF(), losses, predictions.type()), sampleWeights); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index ad8ff58e417..6cc089fce6d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -21,12 +21,10 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.math.Mean; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.Collections; @@ -57,13 +55,13 @@ public class MetricsHelper { * @param values the values to which weights are applied. 
* @return Operation with control dependencies to ensure sampleWeight * can be broadcast to values - * @param the type of Operand + * @param the type of Operand * @throws NotBroadcastableException If static checks determine sampleWeights has an * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") - public static Op assertBroadcastable( - Ops tf, Operand sampleWeights, Operand values) { + public static Op assertBroadcastable( + Ops tf, Operand sampleWeights, Operand values) { // try static check for exact match @@ -129,7 +127,7 @@ public static Op assertBroadcastable( // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. - Operand reshapedWeights = + Operand reshapedWeights = tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); weightsShape = tf.shape(reshapedWeights); weightsRank = tf.rank(reshapedWeights); @@ -237,11 +235,10 @@ public static Operand mean(Ops tf, Operand x) { * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. * @param the type of the Operand. - * @param the type of the axes. * @return the mean of the operand, along the specified axes. */ - public static Operand mean( - Ops tf, Operand x, Operand axes) { + public static Operand mean( + Ops tf, Operand x, Operand axes) { return mean(tf, x, axes, false); } @@ -257,31 +254,27 @@ public static Operand mean( * @param the type of the operand * @return the mean of elements of x. */ - public static Operand mean( - Ops tf, Operand x, boolean keepDims) { + public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); } - - /** * Calculates the mean of the operand, alongside the specified axes. * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. 
- * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @param the data type of the Operand - * @param the data type of the axes * @return the mean of elements of x. */ - - public static Operand mean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand mean( + Ops tf, Operand x, Operand axes, boolean keepDims) { if (axes == null) { - axes = (Operand) allAxes(tf, x); + axes = allAxes(tf, x); } return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } @@ -294,7 +287,7 @@ public static Operand mean( * @param x the Operand used to calculate the mean * @return the mean of the operand containing floating point numbers */ - public static Operand booleanMean(Ops tf, Operand x) { + public static Operand booleanMean(Ops tf, Operand x) { return booleanMean(tf, x, null, false); } @@ -305,11 +298,10 @@ public static Operand booleanMean(Ops tf, Operand x) { * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param the type of the axes. * @return the mean of the operand, along the specified axes, containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x,Operand axes) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes) { return booleanMean(tf, x, axes, false); } @@ -317,14 +309,13 @@ public static Operand booleanMean( * Calculates the mean of the boolean operand, alongside all axes. 
* * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. * @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, boolean keepDims) { + public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); } @@ -333,16 +324,15 @@ public static Operand booleanMean( * * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the - * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the - * * reduced dimensions are retained with length 1. - * @param the data type of the axes + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. 
* @return the mean of elements of x containing floating point numbers */ - public static Operand booleanMean( - Ops tf, Operand x, Operand axes, boolean keepDims) { + public static Operand booleanMean( + Ops tf, Operand x, Operand axes, boolean keepDims) { Operand xf = cast(tf, x, TFloat64.class); return mean(tf, xf, axes, keepDims); } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 8e48cb4e573..2a26967b9f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -19,7 +19,6 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.Metric; import org.tensorflow.framework.metrics.MetricReduction; -import org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -29,13 +28,14 @@ import java.util.ArrayList; import java.util.List; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Encapsulates metrics that perform a reduce operation on the metric values. * - * @param The data type for the metric values * @param The data type for the metric result */ -public abstract class Reduce extends Metric { +public abstract class Reduce extends Metric { public static final String TOTAL = "total"; public static final String COUNT = "count"; protected final MetricReduction reduction; @@ -45,8 +45,10 @@ public abstract class Reduce extends Metri private final Class resultType; /** the variable that holds the total of the metric values */ protected Variable total; - /** the variable that holds the count of the metric values. 
- * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ + /** + * the variable that holds the count of the metric values. For {@link + * MetricReduction#WEIGHTED_MEAN}, this count may be weighted + */ protected Variable count; /** @@ -95,12 +97,10 @@ private void setupVars() { public Op resetStates() { List controls = new ArrayList<>(); if (total != null) { - controls.add( - getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); + controls.add(getTF().assign(total, cast(getTF(), getTF().constant(0), total.type()))); } if (count != null) { - controls.add( - getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); + controls.add(getTF().assign(count, cast(getTF(), getTF().constant(0), count.type()))); } return getTF().withControlDependencies(controls).noOp(); } @@ -115,67 +115,67 @@ public Op resetStates() { * @throws IllegalArgumentException if values is null */ @Override - public List updateStateList(Operand values, Operand sampleWeights) { + public List updateStateList( + Operand values, Operand sampleWeights) { if (values == null) { throw new IllegalArgumentException("values is required."); } + Ops tf = getTF(); List updateOperations = new ArrayList<>(); // cast everything to match the variables - Operand lSampleWeights = null; - Operand lValues = values; + Operand tSampleWeights = null; + Operand tValues = cast(tf, values, getResultType()); if (sampleWeights != null) { - lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); - LossTuple tuple = - LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); - lValues = tuple.getTarget(); - lSampleWeights = tuple.getSampleWeights(); + tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); try { - lSampleWeights = 
MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); } catch (IllegalArgumentException ex) { // if we get here we have static shapes with either // different ranks or different dimension sizes. // first, reduce the values down to the rank of the samples - int valuesRank = lValues.shape().numDimensions(); - int weightsRank = lSampleWeights.shape().numDimensions(); + int valuesRank = tValues.shape().numDimensions(); + int weightsRank = tSampleWeights.shape().numDimensions(); int numAxes = Math.min(0, valuesRank - weightsRank); if (numAxes > 0) { // values rank is greater than weights rank, reduce values to weights rank. int[] axes = new int[numAxes]; for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; if (reduction == MetricReduction.SUM) { - lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + tValues = getTF().reduceSum(tValues, getTF().constant(axes)); } else { - lValues = getTF().math.mean(lValues, getTF().constant(axes)); + tValues = getTF().math.mean(tValues, getTF().constant(axes)); } } } - lValues = getTF().math.mul(lValues, lSampleWeights); + tValues = getTF().math.mul(tValues, tSampleWeights); } - Operand weightedValueSum = - getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand weightedValueSum = + getTF().reduceSum(tValues, LossesHelper.allAxes(getTF(), tValues)); Operand totalUpdate = - getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); + getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; if (reduction != MetricReduction.SUM) { switch (reduction) { case SUM_OVER_BATCH_SIZE: - numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); break; case WEIGHTED_MEAN: - if (lSampleWeights == null) { 
- numValues = - CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + if (tSampleWeights == null) { + numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); } else { numValues = - CastHelper.cast( + cast( getTF(), getTF() - .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), + .reduceSum(tSampleWeights, LossesHelper.allAxes(getTF(), tSampleWeights)), resultType); } break; @@ -202,7 +202,7 @@ public Operand result() { break; case WEIGHTED_MEAN: case SUM_OVER_BATCH_SIZE: - fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); + fResult = getTF().math.divNoNan(total, cast(getTF(), count, resultType)); break; default: throw new UnsupportedOperationException( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java index 7ceedded018..be46bb5c282 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -32,7 +32,7 @@ class BinaryCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] 
trueArray = {1, 0, 1, 0, 1, 1}; @@ -77,7 +77,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 0, 1, 0}; @@ -102,7 +102,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 0, 1, 0, 1, 1}; @@ -128,7 +128,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - BinaryCrossentropy instance = + BinaryCrossentropy instance = new BinaryCrossentropy<>( tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java index 2b4a1d75467..34fc3eef884 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java @@ -31,7 +31,7 @@ class CategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweighted", false, 
0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -55,7 +55,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -104,7 +104,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>(tf, "CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 0, 0, 1}; @@ -129,7 +129,7 @@ public void testLabelSmoothing() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); float labelSmoothing = 0.1F; - CategoricalCrossentropy instance = + CategoricalCrossentropy instance = new CategoricalCrossentropy<>( tf, "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java index 87248d95e48..78b25a21b60 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java @@ -31,7 +31,7 @@ class CategoricalHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { @@ -64,7 +64,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalHinge instance = + CategoricalHinge instance = new CategoricalHinge<>(tf, "CH_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java index a9721ef2f8f..18410416c42 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -31,7 +31,7 @@ class CosineSimilarityTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 
9, 2, -5, -2, 6}; @@ -80,7 +80,7 @@ public void test_axis() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); int axis = 1; - CosineSimilarity instance = + CosineSimilarity instance = new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index 6af5fed4889..a9bd5fac76e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,7 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = + Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java index 28020c0fa1c..267578a492c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -31,7 +31,7 @@ class KLDivergenceTest { public void testUnweighted() { try (TestSession session = 
TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[][] trueArray = {{.5f, .8f, .12f}, {.7f, .43f, .8f}}; @@ -54,7 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - KLDivergence instance = + KLDivergence instance = new KLDivergence<>(tf, "KLD_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java index 31c043e0473..1b5b8fb7d49 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java @@ -32,7 +32,7 @@ class LogCoshErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); float[] trueArray = {1, 9, 2, -5, -2, 6}; @@ -56,7 +56,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - LogCoshError instance = + LogCoshError instance = new LogCoshError<>(tf, "LogCosh_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {1, 9, 2, -5, -2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java index 73241ecbe9f..984895f2ad9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java @@ -32,7 +32,7 @@ class MeanAbsoluteErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -74,7 +74,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanAbsoluteError instance = + MeanAbsoluteError instance = new MeanAbsoluteError<>(tf, "MAE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java index 4c92844b217..0b9e7f6b538 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java @@ -34,7 +34,7 @@ public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, 
instance.getTotal()); @@ -76,7 +76,7 @@ public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { session.setEpsilon(1E-6f); Ops tf = session.getTF(); - MeanAbsolutePercentageError instance = + MeanAbsolutePercentageError instance = new MeanAbsolutePercentageError<>(tf, "MAPE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java index 0b760213015..e42052a9ef1 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java @@ -33,7 +33,7 @@ class MeanSquaredErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); @@ -70,7 +70,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredError instance = + MeanSquaredError instance = new MeanSquaredError<>(tf, "MSE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java index 098a5cb9725..e68d63b8778 100644 --- 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java @@ -32,7 +32,7 @@ class MeanSquaredLogarithmicErrorTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance = + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); session.evaluate(0.0f, instance.getTotal()); @@ -69,7 +69,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MeanSquaredLogarithmicError instance = + MeanSquaredLogarithmicError instance = new MeanSquaredLogarithmicError<>(tf, "MSLE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); session.evaluate(0.0, instance.getTotal()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index cf3c3e44719..75d9ef93168 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -32,7 +32,7 @@ class PoissonTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = + Poisson instance = new Poisson<>(tf, "Poisson_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; @@ -55,7 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - 
Poisson instance = + Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 87af1bd8448..8e1aaea0a8f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -32,7 +32,7 @@ class SparseCategoricalCrossentropyTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -56,7 +56,7 @@ public void testUnweighted() { public void testUnweightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); @@ -105,7 +105,7 @@ public void testWeighted() { public void testWeightedLogits() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); session.run(instance.resetStates()); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java index e3376c224f3..2c80b3451ad 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java @@ -32,7 +32,7 @@ class SquaredHingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testUnweighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { @@ -61,7 +61,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SquaredHinge instance = + SquaredHinge instance = new SquaredHinge<>(tf, "SCE_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { From d7f7e4c5871b3a418e9e2335b95a3c9e0009213b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 14:42:48 -0500 Subject: [PATCH 02/97] Reformat code --- .../annotations/org/tensorflow/op/Ops.java | 6 +- .../losses/CategoricalCrossentropy.java | 36 ++++++------ .../framework/losses/CategoricalHinge.java | 4 +- .../tensorflow/framework/losses/Hinge.java | 20 ++++--- .../tensorflow/framework/losses/Huber.java | 2 +- .../framework/losses/KLDivergence.java | 2 +- .../tensorflow/framework/losses/LogCosh.java | 2 +- .../org/tensorflow/framework/losses/Loss.java | 5 +- .../tensorflow/framework/losses/Losses.java | 28 +++++++--- .../framework/losses/MeanAbsoluteError.java | 2 +- .../losses/MeanAbsolutePercentageError.java | 2 +- .../framework/losses/MeanSquaredError.java | 2 +- .../losses/MeanSquaredLogarithmicError.java | 2 +- .../losses/SparseCategoricalCrossentropy.java | 34 +++++++----- .../framework/losses/SquaredHinge.java | 18 +++--- 
.../framework/losses/impl/LossTuple.java | 2 +- .../framework/losses/impl/LossesHelper.java | 26 ++++----- .../framework/metrics/BinaryCrossentropy.java | 9 ++- .../framework/metrics/CategoricalHinge.java | 4 +- .../framework/metrics/CosineSimilarity.java | 2 +- .../tensorflow/framework/metrics/Hinge.java | 3 +- .../framework/metrics/KLDivergence.java | 5 +- .../framework/metrics/LogCoshError.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 2 +- .../metrics/MeanAbsolutePercentageError.java | 6 +- .../framework/metrics/MeanSquaredError.java | 4 +- .../metrics/MeanSquaredLogarithmicError.java | 6 +- .../tensorflow/framework/metrics/Poisson.java | 6 +- .../SparseCategoricalCrossentropy.java | 13 +++-- .../framework/metrics/SquaredHinge.java | 5 +- .../framework/metrics/impl/LossMetric.java | 2 +- .../framework/metrics/impl/SetsOps.java | 55 ++++++++++--------- .../framework/metrics/HingeTest.java | 6 +- .../framework/metrics/PoissonTest.java | 3 +- .../SparseCategoricalCrossentropyTest.java | 2 +- 35 files changed, 174 insertions(+), 155 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84736ada6a5..007ee9d0d42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -345,10 +345,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -370,8 +370,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 035af9589ae..5aac163c1e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -154,24 +154,26 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, boolean fromLogits, float labelSmoothing) { this(tf, null, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); } /** - * Creates a categorical cross entropy Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and a channel axis of {@link #DEFAULT_AXIS} + * Creates a categorical cross entropy Loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and a channel axis of {@link #DEFAULT_AXIS} * * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 */ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float labelSmoothing) { this(tf, name, fromLogits, labelSmoothing, REDUCTION_DEFAULT, DEFAULT_AXIS); @@ -183,9 +185,10 @@ public CategoricalCrossentropy(Ops tf, String name, boolean fromLogits, float la * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. x=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. x=0.2 means + * that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 * @param reduction Type of Reduction to apply to loss. */ public CategoricalCrossentropy( @@ -199,13 +202,14 @@ public CategoricalCrossentropy( * @param tf the TensorFlow Ops * @param name the name of this loss * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. 
* @param axis The channels axis. axis=-1 corresponds to data format "Channels Last" - * and axis=1 corresponds to data format "Channels First". - * {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} + * and axis=1 corresponds to data format "Channels First". {@link + * Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public CategoricalCrossentropy( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 4e9133d8835..73837ed1756 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -25,7 +25,7 @@ *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions) * and pos=sum(labels*predictions) * - *

labels values are expected to be 0 or 1.

+ *

labels values are expected to be 0 or 1. * *

Standalone usage: * @@ -100,7 +100,7 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 37e7e367b9b..db3569441ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -18,15 +18,16 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the hinge loss between labels and predictions. * - *

loss = maximum(1 - labels * predictions, 0)

. + *

loss = maximum(1 - labels * predictions, 0). * - *

labels values are expected to be -1 or 1. - * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.

+ *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor @@ -124,13 +125,16 @@ public Hinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? 
(Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index e8de632eb09..665a9ac157d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -131,7 +131,7 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index b3c0206b409..2aa1f72092b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -100,7 +100,7 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), 
labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 812260d9881..78325713e3e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -106,7 +106,7 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index 0f9b183f38c..cdd35d28aba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -25,7 +25,7 @@ public abstract class Loss { protected final Reduction reduction; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops @@ -64,7 +64,8 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param The data type of the predictions and loss. 
* @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { return call(labels, predictions, null); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index a5ced3d1df8..2222ebb41f8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -102,8 +102,10 @@ public static Operand meanAbsolutePercentageError( tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); - return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum( + tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); + return tf.math.mul( + cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -149,7 +151,11 @@ public static Operand meanSquaredLogarithmicError( * @return the binary crossentropy loss. */ public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -214,9 +220,10 @@ private static Operand binaryCrossentropyHelper( * @param labels true targets * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels * @return the categorical crossentropy loss. @@ -503,7 +510,11 @@ public static Operand poisson( * @return the sparse categorical crossentropy loss */ public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -650,8 +661,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = - tf.math.rsqrt( - tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); + tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); return tf.math.mul(x, invNorm); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 594de1e1448..03a3cf70110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -96,7 +96,7 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, 
Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 275a2e136a0..6c5242df4f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -96,7 +96,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 31df3e70e0b..f975db55c44 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -96,7 +96,7 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index bef990d22bc..11b8e157e90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -96,7 +96,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java index 3ec33113e89..d04cc67d5d9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SparseCategoricalCrossentropy.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -79,7 +80,8 @@ public class SparseCategoricalCrossentropy extends Loss { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. 
* * @param tf the TensorFlow Ops */ @@ -88,8 +90,8 @@ public SparseCategoricalCrossentropy(Ops tf) { } /** - * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, - * and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy loss using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -122,8 +124,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, Reduction reduction) { } /** - * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}, and - * fromLogits={@link #FROM_LOGITS_DEFAULT}. + * Creates a SparseCategoricalCrossentropy using a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT}, and fromLogits={@link #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param name the name of this loss function @@ -135,7 +137,8 @@ public SparseCategoricalCrossentropy(Ops tf, String name, boolean fromLogits) { /** * Creates a SparseCategoricalCrossentropy loss using {@link Class#getSimpleName()} as the loss - * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link #FROM_LOGITS_DEFAULT}. + * name, a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} and fromLogits={@link + * #FROM_LOGITS_DEFAULT}. * * @param tf the TensorFlow Ops * @param fromLogits Whether to interpret predictions as a tensor of logit values @@ -176,9 +179,10 @@ public SparseCategoricalCrossentropy( /** * Generates an Operand the calculates the loss. * - * If run in Graph mode, the computation will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException} - * if the predictions values are outside the range o [0. to 1.]. In Eager Mode, this call - * will throw {@link IllegalArgumentException}, if the predictions values are outside the range o [0. to 1.] + *

If run in Graph mode, the computation will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException} if the predictions values are outside the + * range o [0. to 1.]. In Eager Mode, this call will throw {@link IllegalArgumentException}, if + * the predictions values are outside the range o [0. to 1.] * * @param labels the truth values or labels * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. @@ -200,12 +204,12 @@ public Operand call( if (!fromLogits) { // add predictions range check for 0 - 1 lPredictions = - LossesHelper.rangeCheck( - getTF(), - "predictions range check [0-1]", - predictions, - cast(getTF(), getTF().constant(0), predictions.type()), - cast(getTF(), getTF().constant(1), predictions.type())); + LossesHelper.rangeCheck( + getTF(), + "predictions range check [0-1]", + predictions, + cast(getTF(), getTF().constant(0), predictions.type()), + cast(getTF(), getTF().constant(1), predictions.type())); } else { lPredictions = predictions; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java index 968624db202..dadbdb3b95e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/SquaredHinge.java @@ -18,6 +18,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** @@ -25,8 +26,8 @@ * *

loss = square(maximum(1 - labels * predictions, 0)) * - *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, they will be - * converted to -1 or 1. + *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -107,7 +108,7 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional SampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If SampleWeights is a tensor @@ -124,13 +125,16 @@ public SquaredHinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? 
(Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.squaredHinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java index 2104937a979..f811549fbca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossTuple.java @@ -18,7 +18,7 @@ import org.tensorflow.types.family.TNumber; /** - * A helper class for loss methods to return labels, target, and sampleWeights + * A helper class for loss methods to return labels, target, and sampleWeights * * @param the data type of the LossTuple entries. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 10067db91ba..66bdd839f09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -32,8 +32,9 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to - * TensorFlow Java. These methods should not be used outside of the losses and metrics packages. 
+ * These are helper methods for Losses and Metrics and will be module private when Java modularity + * is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics + * packages. */ public class LossesHelper { @@ -42,10 +43,10 @@ public class LossesHelper { * *

    *
  1. Squeezes last dim of predictions or labels if their rank - * differs by 1 (using {@link #removeSqueezableDimensions}).
  2. + * differs by 1 (using {@link #removeSqueezableDimensions}). *
  3. Squeezes or expands last dim of sampleWeight if its rank differs by 1 from * the new rank of predictions. If sampleWeight is scalar, it is - * kept scalar.
  4. + * kept scalar. *
* * @param tf the TensorFlow Ops @@ -77,12 +78,13 @@ public static LossTuple squeezeOrExpandDimensions( * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * @param sampleWeights Optional sample weight(s) Operand whose dimensions match + * * prediction. - * @return LossTuple of predictions, labels and sampleWeight. - * Each of them possibly has the last dimension squeezed, sampleWeight could be - * extended by one dimension. If sampleWeight is null, only the possibly shape modified predictions and labels are - * returned. + * @return LossTuple of predictions, labels and sampleWeight + * . Each of them possibly has the last dimension squeezed, sampleWeight + * could be extended by one dimension. If sampleWeight is null, only the possibly + * shape modified predictions and labels are returned. */ public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions, Operand sampleWeights) { @@ -298,8 +300,7 @@ private static Operand reduceWeightedLoss( public static Operand safeMean( Ops tf, Operand losses, long numElements) { Operand totalLoss = tf.reduceSum(losses, allAxes(tf, losses)); - return tf.math.divNoNan( - totalLoss, cast(tf, tf.constant(numElements), losses.type())); + return tf.math.divNoNan(totalLoss, cast(tf, tf.constant(numElements), losses.type())); } /** @@ -383,8 +384,7 @@ public static Operand rangeCheck( */ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { - Operand flatValues = - tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); + Operand flatValues = tf.reshape(values, tf.constant(Shape.of(values.shape().size()))); SetDiff1d diff = tf.setDiff1d(flatValues, allowedValues, TInt32.class); long diffSize = diff.out().shape().size(); diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index abd2dcbbf40..d8bb2a41116 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,8 +21,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * @@ -31,8 +29,8 @@ * * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -42,7 +40,8 @@ public class BinaryCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. 
Larger values of label_smoothing diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index c70f2d8643b..4800fc43c49 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result */ -public class CategoricalHinge< T extends TNumber> extends MeanMetricWrapper +public class CategoricalHinge extends MeanMetricWrapper implements LossMetric { /** @@ -45,7 +45,7 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 5abbd095420..3ae67072955 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. 
*/ -public class CosineSimilarity< T extends TNumber> extends MeanMetricWrapper< T> +public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index e0aced6fa3e..3b84b81e071 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -26,8 +26,7 @@ * * @param The data type for the metric result. */ -public class Hinge extends MeanMetricWrapper - implements LossMetric { +public class Hinge extends MeanMetricWrapper implements LossMetric { /** * Creates a Hinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index fa09f2784b5..f631f562e1d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -27,8 +27,7 @@ * * @param The data type for the metric result. 
*/ -public class KLDivergence extends MeanMetricWrapper - implements LossMetric { +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** * Creates a KLDivergence metric @@ -46,7 +45,7 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index c43551a6948..046937e228b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -27,8 +27,7 @@ * * @param The data type for the metric result. */ -public class LogCoshError extends MeanMetricWrapper< T> - implements LossMetric { +public class LogCoshError extends MeanMetricWrapper implements LossMetric { /** * Creates a LogCoshError metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index d343ec77ab0..977f61648a1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. 
*/ -public class MeanAbsoluteError extends MeanMetricWrapper< T> +public class MeanAbsoluteError extends MeanMetricWrapper implements LossMetric { /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index dd7d151260b..bad5255969a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -26,8 +26,8 @@ * * @param The data type for the metric result. */ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -45,7 +45,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c2bef576b30..5b0d9ec43b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. 
*/ -public class MeanSquaredError< T extends TNumber> extends MeanMetricWrapper< T> +public class MeanSquaredError extends MeanMetricWrapper implements LossMetric { /** @@ -45,7 +45,7 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index c1cf4ca6c9a..35044fee956 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -26,8 +26,8 @@ * * @param The data type for the metric result. 
*/ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -45,7 +45,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index af50b103a60..700099d3375 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -24,11 +24,9 @@ /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param The data type for the metric result. 
*/ -public class Poisson< T extends TNumber> extends MeanMetricWrapper< T> - implements LossMetric { +public class Poisson extends MeanMetricWrapper implements LossMetric { /** * Creates a Poisson metric @@ -46,7 +44,7 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.poisson(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index a0c016b70b3..aa7ca316378 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -23,12 +23,12 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and - * predicted labels. - *\ + * predicted labels. \ + * * @param The data type for the metric result. */ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final int axis; @@ -38,7 +38,8 @@ public class SparseCategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param axis The dimension along which the entropy is computed. 
* @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -54,7 +55,7 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index bd331a85eda..01f4a403f84 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -26,8 +26,7 @@ * * @param The data type for the metric result. */ -public class SquaredHinge extends MeanMetricWrapper - implements LossMetric { +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** * Creates a SquaredHinge metric @@ -45,7 +44,7 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.squaredHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 70bb8133698..037d634cd4a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -31,5 +31,5 @@ public interface LossMetric { * @param predictions the predictions * @return the loss */ - Operand call(Operand 
labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 1841c7ee238..467dea19b57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -25,33 +25,6 @@ /** Implementation of set operations */ public class SetsOps { - /** - * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops - * function {@link SparseOps#denseToDenseSetOperation} - */ - public enum Operation { - A_MINUS_B("a-b"), - B_MINUS_A("b-a"), - INTERSECTION("intersection"), - UNION("union"); - - private final String setOperation; - - Operation(String setOperation) { - this.setOperation = setOperation; - } - - /** - * Gets the set operation String value used to pass as the stringOperation value to {@link - * SparseOps#denseToDenseSetOperation} - * - * @return the set operation String value - */ - public String getSetOperation() { - return setOperation; - } - } - /** * Computes set difference of elements in last dimension of a and b with * aMinusB set to true. @@ -69,6 +42,7 @@ public String getSetOperation() { public static Operand difference(Ops tf, Operand a, Operand b) { return difference(tf, a, b, true); } + /** * Computes set difference of elements in last dimension of a and b. 
* @@ -143,4 +117,31 @@ public static Operand setOperation( setOperationResult.resultValues(), cast(tf, tf.constant(0), a.type())); } + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index a9bd5fac76e..90531d21fde 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,8 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; @@ -55,8 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, 
TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { -1, 1, -1, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index 75d9ef93168..5631bac15ee 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -55,8 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = - new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); + Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java index 8e1aaea0a8f..0aece8c8ac9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -79,7 +79,7 @@ public void testUnweightedLogits() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SparseCategoricalCrossentropy instance = + SparseCategoricalCrossentropy instance = new SparseCategoricalCrossentropy<>( tf, "SCE_testWeighted", false, -1, 1001L, TFloat32.class); session.run(instance.resetStates()); From 6b4149ce0036fca5a3dfe33f283b1e6a980a7f9d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 18:40:00 -0500 Subject: [PATCH 03/97] 
Change order of TrainOps and QuantiQuantizationOps. For some reason, when I build it reverses these 2 from master's version. --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 007ee9d0d42..84736ada6a5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -345,10 +345,10 @@ public final class Ops { public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -370,8 +370,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /** From e486a9038eb22ae05b3e33ee63f8c371f0b509c6 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Feb 2021 15:32:07 -0500 Subject: [PATCH 04/97] Fix LossMetric to change abstract "call" method to use gneric parameter for predictions instead of . 
--- .../framework/metrics/BinaryCrossentropy.java | 8 ++++++-- .../framework/metrics/CategoricalCrossentropy.java | 8 ++++++-- .../framework/metrics/CategoricalHinge.java | 8 ++++++-- .../framework/metrics/CosineSimilarity.java | 11 ++++++++--- .../java/org/tensorflow/framework/metrics/Hinge.java | 8 ++++++-- .../tensorflow/framework/metrics/KLDivergence.java | 8 ++++++-- .../tensorflow/framework/metrics/LogCoshError.java | 8 ++++++-- .../framework/metrics/MeanAbsoluteError.java | 8 ++++++-- .../metrics/MeanAbsolutePercentageError.java | 8 ++++++-- .../framework/metrics/MeanSquaredError.java | 8 ++++++-- .../metrics/MeanSquaredLogarithmicError.java | 8 ++++++-- .../org/tensorflow/framework/metrics/Poisson.java | 8 ++++++-- .../metrics/SparseCategoricalCrossentropy.java | 8 ++++++-- .../tensorflow/framework/metrics/SquaredHinge.java | 8 ++++++-- .../tensorflow/framework/metrics/impl/LossMetric.java | 2 +- 15 files changed, 87 insertions(+), 30 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index d8bb2a41116..263b8a789ed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. 
* @@ -60,7 +62,9 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index be43f34b92e..cbe0127295f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical cross-entropy loss between true labels and predicted * labels. 
@@ -99,8 +101,10 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalCrossentropy( - getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + getTF(), tLabels, tPredictions, fromLogits, labelSmoothing, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 4800fc43c49..ff814ae6ed3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical hinge loss metric between labels and predictions. 
* @@ -45,7 +47,9 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.categoricalHinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.categoricalHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 3ae67072955..d64136d0d90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the cosine similarity metric between labels and predictions. 
* @@ -76,8 +78,11 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity - return Losses.cosineSimilarity(getTF(), labels, predictions, axis); + public Operand call(Operand labels, Operand predictions) { + // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, + // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.cosineSimilarity(getTF(), tLabels, tPredictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 3b84b81e071..7a37cbeddbe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the hinge loss metric between labels and predictions. 
* @@ -44,7 +46,9 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.hinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.hinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index f631f562e1d..3027bb2f460 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. 
@@ -45,7 +47,9 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 046937e228b..ca84e651988 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. 
@@ -45,7 +47,9 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.logCosh(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.logCosh(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 977f61648a1..c91cb0df1ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. 
* @@ -45,7 +47,9 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsoluteError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index bad5255969a..6cc96a4fb88 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. 
* @@ -45,7 +47,9 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 5b0d9ec43b3..1fce9998270 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. 
* @@ -45,7 +47,9 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 35044fee956..900359db88b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. 
* @@ -45,7 +47,9 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 700099d3375..3572c155b96 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the poisson loss metric between labels and predictions. 
* @@ -44,7 +46,9 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.poisson(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.poisson(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index aa7ca316378..a74f575a4a8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. 
\ @@ -55,7 +57,9 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 01f4a403f84..6bee2ccf8e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the squared hinge loss metric between labels and predictions. 
* @@ -44,7 +46,9 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.squaredHinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.squaredHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 037d634cd4a..1fb3d3bb580 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -31,5 +31,5 @@ public interface LossMetric { * @param predictions the predictions * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } From c7115323dce138c6ed6ced16c3aaf435e8cc046e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 6 Feb 2021 16:17:11 -0500 Subject: [PATCH 05/97] Reformat code, fix javadoc --- .../annotations/org/tensorflow/op/Ops.java | 6 ++-- .../framework/initializers/Glorot.java | 1 - .../tensorflow/framework/initializers/He.java | 1 - .../tensorflow/framework/losses/Hinge.java | 7 +--- .../tensorflow/framework/losses/Huber.java | 4 ++- .../tensorflow/framework/losses/LogCosh.java | 3 +- .../tensorflow/framework/losses/Losses.java | 1 + .../tensorflow/framework/losses/Poisson.java | 3 +- .../framework/losses/impl/LossesHelper.java | 7 +++- .../framework/metrics/BinaryCrossentropy.java | 3 +- .../metrics/CategoricalCrossentropy.java | 3 +- .../framework/metrics/CategoricalHinge.java | 3 +- .../framework/metrics/CosineSimilarity.java | 3 +- 
.../tensorflow/framework/metrics/Hinge.java | 3 +- .../framework/metrics/KLDivergence.java | 3 +- .../framework/metrics/LogCoshError.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 3 +- .../metrics/MeanAbsolutePercentageError.java | 3 +- .../framework/metrics/MeanSquaredError.java | 3 +- .../metrics/MeanSquaredLogarithmicError.java | 3 +- .../tensorflow/framework/metrics/Metric.java | 6 +++- .../tensorflow/framework/metrics/Poisson.java | 3 +- .../SparseCategoricalCrossentropy.java | 3 +- .../framework/metrics/SquaredHinge.java | 3 +- .../framework/metrics/impl/MetricsHelper.java | 2 ++ .../framework/optimizers/AdaDelta.java | 34 +++++++++---------- .../framework/optimizers/AdaGrad.java | 8 ++--- .../framework/optimizers/AdaGradDA.java | 2 +- .../framework/optimizers/Adamax.java | 4 +-- .../tensorflow/framework/optimizers/Ftrl.java | 10 +++--- .../framework/optimizers/Nadam.java | 5 ++- .../framework/optimizers/Optimizer.java | 20 ++++++----- .../framework/optimizers/RMSProp.java | 27 +++++++-------- .../framework/utils/CastHelper.java | 2 +- .../framework/utils/ShapeUtils.java | 28 ++++++++------- 35 files changed, 124 insertions(+), 99 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 84736ada6a5..007ee9d0d42 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -345,10 +345,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -370,8 +370,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new 
QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 290e4e80b57..894bd073758 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -62,7 +62,6 @@ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *

For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. - *

* * @param The TType for the call operation * @see VarianceScaling.Distribution diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 9b1a0887af0..3a91b72b0d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -57,7 +57,6 @@ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *

For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. - *

* * @param The TType for the call operation * @see The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - @SuppressWarnings("unchecked") - Operand tLabels = - predictions.type() == labels.type() - ? (Operand) labels - : cast(tf, labels, predictions.type()); + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( getTF(), diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 665a9ac157d..b1aee1b0656 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -89,6 +89,7 @@ public Huber(Ops tf) { * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Huber(Ops tf, String name) { this(tf, name, DELTA_DEFAULT, Reduction.AUTO); @@ -109,6 +110,7 @@ public Huber(Ops tf, Reduction reduction) { * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public Huber(Ops tf, String name, Reduction reduction) { @@ -119,7 +121,7 @@ public Huber(Ops tf, String name, Reduction reduction) { * Creates a Huber Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param delta the point where the Huber loss function changes from quadratic to linear. 
* @param reduction Type of Reduction to apply to the loss. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 78325713e3e..a11d582e527 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -77,6 +77,7 @@ public LogCosh(Ops tf) { * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public LogCosh(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -96,7 +97,7 @@ public LogCosh(Ops tf, Reduction reduction) { * Creates a LogCosh Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public LogCosh(Ops tf, String name, Reduction reduction) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 2222ebb41f8..9aa94cf7fcf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -655,6 +655,7 @@ private static Operand smoothCategoricalLabels( * @param tf The TensorFlow Ops * @param x the input * @param axis Dimension along which to normalize. 
+ * @param the data type for the input and the result * @return the normalized values based on L2 norm */ public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java index 9cf38aa0380..78324acf8a5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Poisson.java @@ -76,6 +76,7 @@ public Poisson(Ops tf) { * Creates a Poisson Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Poisson(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -95,7 +96,7 @@ public Poisson(Ops tf, Reduction reduction) { * Creates a Poisson Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public Poisson(Ops tf, String name, Reduction reduction) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 66bdd839f09..f6b0de71b0d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -53,6 +53,7 @@ public class LossesHelper { * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . 
+ * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, (prediction, @@ -81,6 +82,7 @@ public static LossTuple squeezeOrExpandDimensions( * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * * prediction. + * @param the data type for the labels, predictions and result * @return LossTuple of predictions, labels and sampleWeight * . Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, only the possibly @@ -180,6 +182,7 @@ private static Operand maybeExpandWeights( * @param labels Label values, a Tensor whose dimensions match predictions * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -195,6 +198,7 @@ public static LossTuple removeSqueezableDimensions( *
. * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -218,7 +222,8 @@ public static LossTuple removeSqueezableDimensions( } // Use dynamic rank. - // TODO Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // TODO: hold for lazy select feature, + // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 263b8a789ed..48ee244eafb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -62,7 +62,8 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index cbe0127295f..b22e5415f79 100644 --- 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -101,7 +101,8 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalCrossentropy( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index ff814ae6ed3..4266cc487c0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -47,7 +47,8 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalHinge(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index d64136d0d90..840f255c5ab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -78,7 +78,8 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public 
Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity Operand tLabels = cast(getTF(), labels, getResultType()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 7a37cbeddbe..46ccd2859ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -46,7 +46,8 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.hinge(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 3027bb2f460..9ffcd6189f1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -47,7 +47,8 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index ca84e651988..59e24f57110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -47,7 +47,8 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.logCosh(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index c91cb0df1ef..1cc6d0b6f99 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -47,7 +47,8 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 6cc96a4fb88..8c6720b58f6 100644 --- 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -47,7 +47,8 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 1fce9998270..3c4c79d39ba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -47,7 +47,8 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanSquaredError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 900359db88b..d525bb76648 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -47,7 
+47,8 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java index 8ab21c58218..468919e696d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -184,7 +184,11 @@ public String getName() { return name; } - /** The random number generator seed value */ + /** + * Gets the random number generator seed value + * + * @return the random number generator seed value + */ public long getSeed() { return seed; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 3572c155b96..422fd4808ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -46,7 +46,8 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.poisson(getTF(), tLabels, tPredictions); diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index a74f575a4a8..9949f0c6b60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -57,7 +57,8 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 6bee2ccf8e4..19b3b1d0ac4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -46,7 +46,8 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.squaredHinge(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 6cc089fce6d..8a352322f52 100644 --- 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -308,6 +308,7 @@ public static Operand booleanMean( /** * Calculates the mean of the boolean operand, alongside all axes. * + * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is * false, the rank of the tensor is reduced by 1 for each entry in axes @@ -322,6 +323,7 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke /** * Calculates the mean of the boolean operand, alongside the specified axes. * + * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index 822eb490f22..aadbfeea54b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -31,29 +31,29 @@ * learning rate per dimension to address two drawbacks: * *
    - *
  • the continual decay of learning rates throughout training - *
  • the need for a manually selected global learning rate + *
  • the continual decay of learning rates throughout training + *
  • the need for a manually selected global learning rate *
* - *

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a - * moving window of gradient updates, instead of accumulating all past gradients. This way, - * Adadelta continues learning even when many updates have been done. Compared to Adagrad, in - * the original version of Adadelta you don't have to set an initial learning rate. In this - * version, initial learning rate can be set, as in most other optimizers. + *

Adadelta is a more robust extension of Adagrad that adapts learning rates based on a moving + * window of gradient updates, instead of accumulating all past gradients. This way, Adadelta + * continues learning even when many updates have been done. Compared to Adagrad, in the original + * version of Adadelta you don't have to set an initial learning rate. In this version, initial + * learning rate can be set, as in most other optimizers. * - *

According to section 4.3 ("Effective Learning rates"), near the end of training step sizes - * converge to 1 which is effectively a high learning rate which would cause divergence. This - * occurs only near the end of the training as gradients and step sizes are small, and the - * epsilon constant in the numerator and denominator dominate past gradients and parameter - * updates which converge the learning rate to 1. + *

According to section 4.3 ("Effective Learning rates"), near the end of training step sizes + * converge to 1 which is effectively a high learning rate which would cause divergence. This occurs + * only near the end of the training as gradients and step sizes are small, and the epsilon constant + * in the numerator and denominator dominate past gradients and parameter updates which converge the + * learning rate to 1. * - *

According to section 4.4("Speech Data"),where a large neural network with 4 hidden layers - * was trained on a corpus of US English data, ADADELTA was used with 100 network replicas.The - * epsilon used is 1e-6 with rho=0.95 which converged faster than ADAGRAD, by the following - * construction: new AdaDelta(graph, 1.0f, 0.95f, 1e-6f); + *

According to section 4.4("Speech Data"),where a large neural network with 4 hidden layers was + * trained on a corpus of US English data, ADADELTA was used with 100 network replicas.The epsilon + * used is 1e-6 with rho=0.95 which converged faster than ADAGRAD, by the following construction: + * new AdaDelta(graph, 1.0f, 0.95f, 1e-6f); * * @see Zeiler, M., 2012 ADADELTA: An Adaptive Learning - * Rate Method. + * Rate Method */ public class AdaDelta extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 08f5f18a9cd..2dd05ef31b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -31,10 +31,10 @@ * how frequently a parameter gets updated during training. The more updates a parameter receives, * the smaller the updates. * - *

- * - * @see Duchi, J, et al., 2011, Adaptive Subgradient Methods for Online Learning and Stochastic Optimization - * @see Duchi, J, et al., 2013, Proximal and First-Order Methods for Convex Optimization, Introduction Section 1. + * @see Duchi, J, et al., 2011, + * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization + * @see Duchi, J, et al., + * 2013, Proximal and First-Order Methods for Convex Optimization, Introduction Section 1 */ public class AdaGrad extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index df624e41c4e..7114c33339f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -40,7 +40,7 @@ * networks as it will require careful initialization of the gradient accumulators for it to train. * * @see Duchi, J, et al., 2011, - * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization. 
+ * Adaptive Subgradient Methods for Online Learning and Stochastic Optimization */ public class AdaGradDA extends Optimizer { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index cd95bb3bd07..0ecc1ac1451 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -32,12 +32,10 @@ public class Adamax extends Optimizer { public static final float EPSILON_DEFAULT = 1e-07f; public static final float BETA_ONE_DEFAULT = 0.9f; public static final float BETA_TWO_DEFAULT = 0.999f; - - private float learningRate; private final float betaOne; private final float betaTwo; private final float epsilon; - + private final float learningRate; private Constant learningRateConst; private Constant epsilonConst; private Constant betaOneConst; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java index 66314d2ffe0..5d8c1478231 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Ftrl.java @@ -13,10 +13,11 @@ /** * Optimizer that implements the FTRL algorithm. * + *

This version has support for both online L2 (the L2 penalty given in the paper below) and + * shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). + * * @see McMahan, et - * al., 2013, Algorithm 1 - *

This version has support for both online L2 (the L2 penalty given in the paper above) and - * shrinkage-type L2 (which is the addition of an L2 penalty to the loss function). + * al., 2013, Algorithm 1 */ public class Ftrl extends Optimizer { @@ -29,13 +30,12 @@ public class Ftrl extends Optimizer { public static final float L1STRENGTH_DEFAULT = 0.0f; public static final float L2STRENGTH_DEFAULT = 0.0f; public static final float L2_SHRINKAGE_REGULARIZATION_STRENGTH_DEFAULT = 0.0f; - - private float learningRate; private final float learningRatePower; private final float initialAccumulatorValue; private final float l1RegularizationStrength; private final float l2RegularizationStrength; private final float l2ShrinkageRegularizationStrength; + private final float learningRate; /** * Creates a Ftrl Optimizer diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java index f9900a8ee78..5b94b548c0a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Nadam.java @@ -24,8 +24,6 @@ */ public class Nadam extends Optimizer { - private static final float DECAY_BASE = 0.96f; - private static final float DECAY = 0.004f; public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float EPSILON_DEFAULT = 1e-8f; public static final float BETA_ONE_DEFAULT = 0.9f; @@ -33,7 +31,8 @@ public class Nadam extends Optimizer { public static final String FIRST_MOMENT = "m"; public static final String SECOND_MOMENT = "v"; public static final String MOMENTUM = "momentum"; - + private static final float DECAY_BASE = 0.96f; + private static final float DECAY = 0.004f; /** The learning rate. 
*/ private final float learningRate; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java index fdf56da4a67..ed141831bbe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Optimizer.java @@ -71,14 +71,6 @@ protected Optimizer(Graph graph, String name) { this.globals = new ArrayList<>(); } - /** - * Gets the Optimizer's Ops instance - * @return the Optimizer's Ops instance - */ - public final Ops getTF() { - return tf; - } - /** * Creates a name by combining a variable name and a slot name * @@ -90,6 +82,15 @@ public static String createName(Output variable, String slotNam return variable.op().name() + "-" + slotName; } + /** + * Gets the Optimizer's Ops instance + * + * @return the Optimizer's Ops instance + */ + public final Ops getTF() { + return tf; + } + /** * Minimizes the loss by updating the variables * @@ -299,7 +300,8 @@ private Options() {} * Sets the shared name * * @param sharedName If non-empty, this variable is named in the given bucket with this - * shared_name. Otherwise, the node name is used instead. + * sharedName. Otherwise, the node name is used instead. + * @return this options instance */ public Optimizer.Options sharedName(String sharedName) { this.sharedName = sharedName; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index b3729dc367f..e86e64971a4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -27,17 +27,20 @@ /** * Optimizer that implements the RMSProp algorithm. * - *

The gist of RMSprop is to:

    - *
  • Maintain a moving (discounted) average of the square of gradients - *
  • Divide the gradient by the root of this average
+ *

The gist of RMSprop is to: * - *

This implementation of RMSprop uses plain momentum, not Nesterov momentum. + *

    + *
  • Maintain a moving (discounted) average of the square of gradients + *
  • Divide the gradient by the root of this average + *
* - *

The centered version additionally maintains a moving average of the gradients, and uses - * that average to estimate the variance. + *

This implementation of RMSprop uses plain momentum, not Nesterov momentum. + * + *

The centered version additionally maintains a moving average of the gradients, and uses that + * average to estimate the variance. * * @see Hinton G, - * et al. 2012, lecture notes that is inexplicably the canonical reference. + * et al. 2012, lecture notes, that is inexplicably the canonical reference. */ public class RMSProp extends Optimizer { @@ -165,24 +168,20 @@ protected void createSlots(List> variables) { } } - /** - * Creates the RMSProp Slots for Root Mean Squared (RMS), - * MOMENTUM, and Mean Gradient (MG) + * Creates the RMSProp Slots for Root Mean Squared (RMS), MOMENTUM, and Mean Gradient (MG) * * @param v the variable to install in the slot * @param the datatype of the variable. */ private void createRMSPropSlot(Output v) { - Operand rmsInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); + Operand rmsInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(1.0f), v.type())); createSlot(v.asOutput(), RMS, rmsInitializer); Operand momentumInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MOMENTUM, momentumInitializer); if (centered) { - Operand mgInitializer = - tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); + Operand mgInitializer = tf.fill(tf.shape(v), tf.dtypes.cast(tf.constant(0.0f), v.type())); createSlot(v.asOutput(), MG, mgInitializer); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java index b0fe48967dd..1c027cb5ddf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/CastHelper.java @@ -34,7 +34,7 @@ public class CastHelper { */ @SuppressWarnings("unchecked") public static Operand cast( - Ops tf, Operand value, Class requiredType) { + Ops tf, Operand value, Class requiredType) { return 
(value.type() == requiredType) ? (Operand) value : tf.dtypes.cast(value, requiredType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index 4ca2c789f28..e730c79cfbf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -14,8 +14,9 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; -import org.tensorflow.ndarray.NdArray; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TInt32; @@ -33,7 +34,9 @@ public class ShapeUtils { /** * Converts a shape operand to a Shape object * + * @param scope the TensorFlow scope * @param dims the Operand containing the shape values + * @param the date type for the shape dimensions. 
* @return a new Shape based on an Operand that contains dimensions */ public static Shape toShape(Scope scope, Operand dims) { @@ -45,8 +48,8 @@ public static Shape toShape(Scope scope, Operand dims) * Converts a TInt32 type Operand to a Java int array * * @param scope the TensorFlow scope - * @param dims the TInt32 Operand - * @return the int array + * @param dims the shape dimensions operand + * @return the int array of the dimensions */ public static int[] getIntArray(Scope scope, Operand dims) { long[] longDims = getLongArray(scope, dims); @@ -66,8 +69,8 @@ public static long[] getLongArray(Scope scope, Operand if (scope.env().isEager()) { return getLongArray(dims.asTensor()); } - try (Session session = new Session((Graph)scope.env()); - TIntegral tensor = (TIntegral)session.runner().fetch(dims).run().get(0)) { + try (Session session = new Session((Graph) scope.env()); + TIntegral tensor = (TIntegral) session.runner().fetch(dims).run().get(0)) { return getLongArray(tensor); } } @@ -76,20 +79,21 @@ public static long[] getLongArray(Scope scope, Operand * Converts a TInt32 or TInt64 to a java long array * * @param dims the dimension tensor + * @param the type of the dimensions, must either be TInt32 or TInt64 type * @return the long array * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ public static long[] getLongArray(T dims) { List result = new ArrayList<>(); if (dims instanceof TInt32) { - ((TInt32)dims).scalars().forEach(s -> result.add((long) s.getInt())); + ((TInt32) dims).scalars().forEach(s -> result.add((long) s.getInt())); } else if (dims instanceof TInt64) { - ((TInt64)dims).scalars().forEach(s -> result.add(s.getLong())); + ((TInt64) dims).scalars().forEach(s -> result.add(s.getLong())); } else if (dims instanceof TUint8) { - ((TUint8)dims).scalars().forEach(s -> result.add(s.getObject().longValue())); - } else { // shouldn't happen - throw new IllegalArgumentException("the data type must be an integer type"); - } + 
((TUint8) dims).scalars().forEach(s -> result.add(s.getObject().longValue())); + } else { // shouldn't happen + throw new IllegalArgumentException("the data type must be an integer type"); + } return result.stream().mapToLong(i -> i).toArray(); } From 1dfb7c38068fcc9b65a88de512d199484c65e292 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Feb 2021 13:51:02 -0500 Subject: [PATCH 06/97] Update with new generic parameters --- .../org/tensorflow/framework/metrics/AUC.java | 1018 +++++++++++++++++ .../framework/metrics/AUCCurve.java | 36 + .../framework/metrics/AUCSummationMethod.java | 41 + .../framework/metrics/Accuracy.java | 89 ++ .../framework/metrics/BinaryAccuracy.java | 100 ++ .../metrics/CategoricalAccuracy.java | 85 ++ .../framework/metrics/FalseNegatives.java | 128 +++ .../framework/metrics/FalsePositives.java | 129 +++ .../tensorflow/framework/metrics/MeanIoU.java | 163 +++ .../framework/metrics/MeanRelativeError.java | 173 +++ .../framework/metrics/MeanTensor.java | 186 +++ .../tensorflow/framework/metrics/Metrics.java | 52 +- .../framework/metrics/Precision.java | 400 +++++++ .../framework/metrics/PrecisionAtRecall.java | 122 ++ .../tensorflow/framework/metrics/Recall.java | 426 +++++++ .../framework/metrics/RecallAtPrecision.java | 132 +++ .../metrics/RootMeanSquaredError.java | 87 ++ .../metrics/SensitivityAtSpecificity.java | 150 +++ .../metrics/SparseCategoricalAccuracy.java | 135 +++ .../SparseTopKCategoricalAccuracy.java | 70 ++ .../metrics/SpecificityAtSensitivity.java | 151 +++ .../org/tensorflow/framework/metrics/Sum.java | 60 + .../metrics/TopKCategoricalAccuracy.java | 70 ++ .../framework/metrics/TrueNegatives.java | 129 +++ .../framework/metrics/TruePositives.java | 128 +++ .../impl/ConfusionMatrixConditionCount.java | 186 +++ .../metrics/impl/ConfusionMatrixEnum.java | 57 + .../framework/metrics/impl/MetricsHelper.java | 487 +++++++- .../impl/SensitivitySpecificityBase.java | 277 +++++ .../framework/metrics/impl/SymbolicShape.java | 56 
+ .../metrics/impl/WeightsBroadcastOps.java | 186 +++ .../framework/utils/SparseTensor.java | 75 ++ .../tensorflow/framework/metrics/AUCTest.java | 324 ++++++ .../framework/metrics/AccuracyTest.java | 130 +++ .../framework/metrics/BinaryAccuracyTest.java | 177 +++ .../metrics/CategoricalAccuracyTest.java | 156 +++ .../framework/metrics/FalseNegativesTest.java | 141 +++ .../framework/metrics/FalsePositivesTest.java | 148 +++ .../framework/metrics/MeanIoUTest.java | 109 ++ .../metrics/MeanRelativeErrorTest.java | 100 ++ .../framework/metrics/MeanTensorTest.java | 119 ++ .../metrics/PrecisionAtRecallTest.java | 179 +++ .../framework/metrics/PrecisionTest.java | 339 ++++++ .../metrics/RecallAtPrecisionTest.java | 207 ++++ .../framework/metrics/RecallTest.java | 341 ++++++ .../metrics/RootMeanSquaredErrorTest.java | 72 ++ .../metrics/SensitivityAtSpecificityTest.java | 185 +++ .../metrics/SpecificityAtSensitivityTest.java | 184 +++ .../tensorflow/framework/metrics/SumTest.java | 113 ++ .../metrics/TopKCategoricalAccuracyTest.java | 103 ++ .../framework/metrics/TrueNegativesTest.java | 141 +++ .../framework/metrics/TruePositivesTest.java | 141 +++ .../metrics/impl/AssertBroadcastableTest.java | 1 + 53 files changed, 8988 insertions(+), 6 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java 
create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java create mode 100644 
tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java create mode 100644 
tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java new file mode 100644 index 00000000000..62311c3cda5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -0,0 +1,1018 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.framework.metrics.impl.SymbolicShape; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. + * + *

This metric creates four local variables, truePositives`, trueNegatives`, + * falsePositives` and falseNegatives` that are used to compute the AUC. To discretize the AUC + * curve, a linearly spaced set of thresholds is used to compute pairs of recall and precision + * values. The area under the ROC-curve is therefore computed using the height of the recall values + * by the false positive rate, while the area under the PR-curve is the computed using the height of + * the precision values by the recall. + * + *

This value is ultimately returned as auc, an idempotent operation that computes the area + * under a discretized curve of precision versus recall values (computed using the aforementioned + * variables). The numThresholds variable controls the degree of discretization with larger + * numbers of thresholds more closely approximating the true AUC. The quality of the approximation + * may vary dramatically depending on numThresholds`. The thresholds parameter can be used to + * manually specify thresholds which split the predictions more evenly. + * + *

For best results, predictions should be distributed approximately uniformly in the range [0, + * 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor if this is not + * the case. Setting summationMethod to minoring or majoring can help quantify the error in + * the approximation by providing lower or upper bound estimate of the AUC. + *

+ *

+ * Usage:
+ *

+ * AUC m = new  getTF().keras.metrics.AUC( getTF(), 3);
+ * m.updateState( getTF().constant(new float[] {0, 0, 1,1}),
+ *          getTF().constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
+ *
+ * // threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
+ * // tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+ * // recall = [1, 0.5, 0], fpRate = [1, 0, 0]
+ * // auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75
+ * Operand<TFloat32> result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.75
+ * 
+ *
+ * m.resetStates()
+ * m.updateState( getTF().constant(new float[] {0, 0, 1, 1}),
+ *                 getTF().constant(new float[] {0f, 0.5f, 0.3f, 0.9f}, ),
+ *                 getTF().constant(new float[] {1, 0, 0, 1}));
+ * result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 1.0
+ * 
+ * + * @param The data type for the metric result + */ +public class AUC extends Metric { + + /** Default Fuzz factor. */ + public static final float EPSILON = 1e-7f; + + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + public static final int DEFAULT_NUM_THRESHOLDS = 200; + public static final String DEFAULT_NAME = "auc"; + + private final int numThresholds; + private final AUCCurve curve; + private final AUCSummationMethod summationMethod; + private final float[] thresholds; + private final boolean multiLabel; + private final String truePositivesName; + private final String falsePositivesName; + private final String trueNegativesName; + private final String falseNegativesName; + private final Map> initializers = new HashMap<>(); + private final Class type; + private Integer numLabels; + private Operand labelWeights; + private Variable truePositives; + private Variable falsePositives; + private Variable trueNegatives; + private Variable falseNegatives; + private boolean initialized; + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link #DEFAULT_NUM_THRESHOLDS} for the numThresholds, {@link AUCCurve#ROC} for the curve type, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for + * thresholds, false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, long seed, Class type) { + this( + tf, + null, + DEFAULT_NUM_THRESHOLDS, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NUM_THRESHOLDS} for the + * numThresholds, {@link AUCCurve#ROC} for the curve type, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, null for thresholds, + * false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, long seed, Class type) { + this( + tf, + name, + DEFAULT_NUM_THRESHOLDS, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the + * summation method, null for thresholds, false for multiLabel, and + * null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, int numThresholds, long seed, Class type) { + this( + tf, + null, + numThresholds, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the + * summation method, null for numThresholds, false for multiLabel, and + * null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, float[] thresholds, long seed, Class type) { + this( + tf, + null, + null, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric. using {@link AUCCurve#ROC} for the curve type, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for + * thresholds, false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, String name, int numThresholds, long seed, Class type) { + this( + tf, + name, + numThresholds, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the summation + * method, {@link #DEFAULT_NUM_THRESHOLDS} num thresholds, false for multiLabel, and + * null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { + this( + tf, + name, + null, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link AUCSummationMethod#INTERPOLATION} for + * the summation method, null for thresholds, false for multiLabel, and + * null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, int numThresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + name, + numThresholds, + curve, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, {@link #DEFAULT_NUM_THRESHOLDS} num + * thresholds, false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + name, + null, + curve, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for + * thresholds, false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. 
+ * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + null, + numThresholds, + curve, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, false for multiLabel, + * and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + null, + null, + curve, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, + * null for thresholds, false for multiLabel, and null for + * labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. 
+ * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + int numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, null, numThresholds, curve, summationMethod, null, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * null for numThresholds, false for multiLabel, and null + * for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + float[] thresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, null, null, curve, summationMethod, thresholds, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric. using null for thresholds, + * false for multiLabel, and null for labelWeights. 
+ * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used, + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + String name, + int numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, name, numThresholds, curve, summationMethod, null, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric. using null> for the numThresholds, + * false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC( + Ops tf, + String name, + float[] thresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, name, null, curve, summationMethod, thresholds, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS} + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein + * AUC is computed separately for each label and then averaged across labels, or (when false) + * if the data should be flattened into a single label before AUC computation. In the latter + * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an + * individual data point. Should be set to false for multi-class data. + * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When + * multi_label is True, the weights are applied to the individual label AUCs when they are + * averaged to produce the multi-label AUC. When it's false, they are used to weight the + * individual label predictions in computing the confusion matrix on the flattened data. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the confusion matrix variables. + * @throws IllegalArgumentException if numThresholds is less than 2 and thresholds is null, or if + * a threshold value is less than 0 or greater than 1. + */ + public AUC( + Ops tf, + String name, + Integer numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + float[] thresholds, + boolean multiLabel, + Operand labelWeights, + long seed, + Class type) { + super(tf, name == null ? DEFAULT_NAME : name, seed); + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); + this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES); + this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); + this.curve = curve; + this.summationMethod = summationMethod; + this.type = type; + + this.multiLabel = multiLabel; + + if (thresholds != null) { // ignore numThresholds + for (float t : thresholds) + if (t < 0.0f || t > 1.0f) + throw new IllegalArgumentException( + String.format( + "Threshold values must be in [0, 1]. Invalid values: %s", + Arrays.toString(thresholds))); + this.numThresholds = thresholds.length + 2; + Arrays.sort(thresholds); + } else { + if (numThresholds <= 1) throw new IllegalArgumentException("numThresholds must be > 1."); + this.numThresholds = numThresholds; + thresholds = new float[numThresholds - 2]; + // linearly interpolate (numThresholds - 2) thresholds between endpoints + for (int i = 0; i < thresholds.length; i++) { + thresholds[i] = (i + 1) * 1.0f / (this.numThresholds - 1); + } + } + // Add an endpoint "threshold" below zero and above one for either + // threshold method to account for floating point imprecision. 
+ if (thresholds.length != this.numThresholds - 2) + throw new IllegalArgumentException( + "Thresholds length must contain numThresholds - 2 entries"); + this.thresholds = new float[this.numThresholds]; + this.thresholds[0] = -EPSILON; + System.arraycopy(thresholds, 0, this.thresholds, 1, thresholds.length); + this.thresholds[this.numThresholds - 1] = 1 + EPSILON; + + if (labelWeights != null) { + // assert that labelWeights are non-negative. + + this.labelWeights = labelWeights; + Op checks = + getTF() + .withSubScope("AUC") + .assertThat( + getTF() + .math + .greaterEqual( + labelWeights, cast(getTF(), getTF().constant(0), labelWeights.type())), + Collections.singletonList( + getTF().constant("All values of `labelWeights` must be non-negative."))); + + Ops ltf = + getTF() + .withSubScope("updateState") + .withControlDependencies(Collections.singletonList(checks)); + + this.labelWeights = ltf.identity(this.labelWeights); + } + + if (this.multiLabel) { + this.numLabels = null; + } + } + + /** + * Initialize truePositives, falsePositives, trueNegatives, and falseNegatives variables, given + * the shape of the data. + * + * @param shape the prediction shape if called from updateState, otherwise null + */ + @SuppressWarnings("unchecked") + private Map> build(Shape shape) { + Shape variableShape; + if (initialized) { + return Collections.EMPTY_MAP; + } + + if (this.isMultiLabel()) { + if (shape == null) { + throw new IllegalArgumentException("For multiLabel, a shape must be provided"); + } + if (shape.numDimensions() != 2) + throw new IllegalArgumentException( + String.format( + "labels must have rank=2 when multiLabel is true. 
Found rank %d.", + shape.numDimensions())); + this.numLabels = (int) shape.size(1); + variableShape = Shape.of(this.numThresholds, this.numLabels); + } else { + variableShape = Shape.of(this.numThresholds); + } + + Zeros zeros = new Zeros<>(getTF()); + Operand zero = zeros.call(getTF().constant(variableShape), type); + if (truePositives == null) { + truePositives = getTF().withName(getTruePositivesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTF().assign(truePositives, zero)); + } + + if (falsePositives == null) { + falsePositives = getTF().withName(getFalsePositivesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, getTF().assign(falsePositives, zero)); + } + + if (trueNegatives == null) { + trueNegatives = getTF().withName(getTrueNegativesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, getTF().assign(trueNegatives, zero)); + } + + if (falseNegatives == null) { + falseNegatives = getTF().withName(getFalseNegativesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, getTF().assign(falseNegatives, zero)); + } + + this.initialized = true; + return initializers; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + + Operand lLabels = cast(getTF(), labels, type); + Operand lPredictions = cast(getTF(), predictions, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(getTF(), sampleWeights, type) : null; + List updateOperations = new ArrayList<>(); + Map> varInitializers = Collections.EMPTY_MAP; + if (!this.initialized) { + varInitializers = build(lPredictions.shape()); + } + if (this.isMultiLabel() || this.getLabelWeights() != null) { + List> symbols = new ArrayList<>(); + symbols.add(new SymbolicShape<>(lLabels, "N", "L")); + if (this.isMultiLabel()) { + symbols.add(new SymbolicShape<>(this.truePositives, "T", "L")); + symbols.add(new SymbolicShape<>(this.falsePositives, "T", "L")); + symbols.add(new SymbolicShape<>(this.trueNegatives, "T", "L")); + symbols.add(new SymbolicShape<>(this.falseNegatives, "T", "L")); + } + if (this.getLabelWeights() != null) { + symbols.add(new SymbolicShape<>(this.getLabelWeights(), "L", "")); + } + updateOperations.addAll( + MetricsHelper.assertShapes(getTF(), symbols, "Number of labels is not consistent.")); + } + if (this.isMultiLabel()) { + this.labelWeights = null; + } + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.falsePositives); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.trueNegatives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives); + + updateOperations.addAll( + MetricsHelper.updateConfusionMatrixVariables( + getTF(), + confusionMatrix, + varInitializers, + lLabels, + lPredictions, + this.thresholds, + null, + null, + tSampleWeights, + this.isMultiLabel(), + this.getLabelWeights())); + return updateOperations; + } + + /** + * Interpolation formula inspired by section 4 of Davis & Goadrich 2006. + * + * @return an approximation of the area under the P-R curve. 
+ */ + private Operand interpolatePRAuc() { + // truePositives[:self.numThresholds - 1] + Operand tp0 = + getTF() + .slice( + truePositives, + getTF().constant(new int[] {0}), + getTF().constant(new int[] {this.getNumThresholds() - 1})); + // truePositives[1:] + Operand tp1 = + getTF() + .slice( + truePositives, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1})); + + Operand dTP = getTF().math.sub(tp0, tp1); + + Operand p = getTF().math.add(truePositives, falsePositives); + + Operand dP = + getTF() + .math + .sub( + getTF() + .slice( + p, + getTF().constant(new int[] {0}), + getTF().constant(new int[] {this.getNumThresholds() - 1})), + getTF() + .slice(p, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1}))); + + Operand precisionSlope = + getTF() + .math + .divNoNan( + dTP, getTF().math.maximum(dP, getTF().dtypes.cast(getTF().constant(0), dP.type()))); + + Operand intercept = + getTF() + .math + .sub( + getTF() + .slice( + truePositives, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})), + getTF() + .math + .mul( + precisionSlope, + getTF() + .slice( + p, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})))); + + Operand safePRatio = + getTF() + .select( + getTF() + .math + .logicalAnd( + getTF() + .math + .greater( + getTF() + .slice( + p, + getTF().constant(new int[] {0}), + getTF().constant(new int[] {this.getNumThresholds() - 1})), + getTF().dtypes.cast(getTF().constant(0), p.type())), + getTF() + .math + .greater( + getTF() + .slice( + p, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})), + getTF().dtypes.cast(getTF().constant(0), p.type()))), + getTF() + .math + .divNoNan( + getTF() + .slice( + p, + getTF().constant(new int[] {0}), + getTF().constant(new int[] {this.getNumThresholds() - 1})), + getTF() + .math + .maximum( + getTF() + .slice( + p, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})), + getTF().dtypes.cast(getTF().constant(0), 
p.type()))), + getTF() + .onesLike( + getTF() + .slice( + p, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})))); + + Operand fn1 = + getTF() + .slice( + falseNegatives, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1})); + + Operand aucTotalPos = + getTF() + .math + .mul( + precisionSlope, + getTF().math.add(dTP, getTF().math.mul(intercept, getTF().math.log(safePRatio)))); + + Operand prAucIncrement = + getTF() + .math + .divNoNan( + aucTotalPos, + getTF() + .math + .maximum( + getTF().math.add(tp1, fn1), + getTF().dtypes.cast(getTF().constant(0), this.truePositives.type()))); + + if (this.isMultiLabel()) { + Operand byLabelAuc = getTF().reduceSum(prAucIncrement, getTF().constant(0)); + if (this.getLabelWeights() == null) { + return MetricsHelper.mean(getTF(), byLabelAuc); + } else { + return getTF() + .math + .divNoNan( + getTF() + .reduceSum( + getTF().math.mul(byLabelAuc, this.getLabelWeights()), + allAxes(getTF(), byLabelAuc)), + getTF().reduceSum(getLabelWeights(), allAxes(getTF(), getLabelWeights()))); + } + } else { + return getTF().reduceSum(prAucIncrement, allAxes(getTF(), prAucIncrement)); + } + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + + if (this.getCurve() == AUCCurve.PR + && this.getSummationMethod() == AUCSummationMethod.INTERPOLATION) { + return this.interpolatePRAuc(); + } + Ops tf = getTF(); + Operand x; + Operand y; + Operand recall = + getTF().math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + + if (this.getCurve() == AUCCurve.ROC) { + x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives)); + y = recall; + } else { // AUCCurve.PR + y = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + x = recall; + } + + // Find the rectangle heights based on `summationMethod`. 
+ // y[:self.numThresholds - 1] + Operand ySlice1 = + tf.slice( + y, tf.constant(new int[] {0}), tf.constant(new int[] {this.getNumThresholds() - 1})); + // y[1:] + Operand ySlice2 = tf.slice(y, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + + Operand heights = null; + switch (this.getSummationMethod()) { + case INTERPOLATION: + heights = + tf.math.div(tf.math.add(ySlice1, ySlice2), tf.dtypes.cast(tf.constant(2), y.type())); + break; + case MINORING: + heights = tf.math.minimum(ySlice1, ySlice2); + break; + case MAJORING: + heights = tf.math.maximum(ySlice1, ySlice2); + break; + } + + if (this.isMultiLabel()) { + Operand riemannTerms = + tf.math.mul( + tf.math.sub( + tf.slice( + x, + tf.constant(new int[] {0}), + tf.constant(new int[] {this.getNumThresholds() - 1})), + tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))), + heights); + Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0)); + + if (this.getLabelWeights() == null) { + return MetricsHelper.mean(tf, byLabelAuc); + } else { + return tf.math.divNoNan( + tf.reduceSum( + tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())), + tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()))); + } + + } else { + Operand slice1 = + tf.slice( + x, tf.constant(new int[] {0}), tf.constant(new int[] {this.getNumThresholds() - 1})); + Operand slice2 = tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand sub = tf.math.sub(slice1, slice2); + Operand operand = tf.math.mul(sub, heights); + return tf.reduceSum(operand, allAxes(tf, operand)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + List updateOperations = new ArrayList<>(initializers.values()); + return getTF().withSubScope("resetStates").withControlDependencies(updateOperations).noOp(); + } + + /** @return the numThresholds */ + public int getNumThresholds() { + return numThresholds; + } + + /** @return the curve */ + public AUCCurve getCurve() { 
+ return curve; + } + + /** @return the summationMethod */ + public AUCSummationMethod getSummationMethod() { + return summationMethod; + } + + /** @return the thresholds */ + public float[] getThresholds() { + return thresholds; + } + + /** @return the multiLabel */ + public boolean isMultiLabel() { + return multiLabel; + } + + /** @return the numLabels */ + public Integer getNumLabels() { + return numLabels; + } + + /** @param numLabels the numLabels to set */ + public void setNumLabels(Integer numLabels) { + this.numLabels = numLabels; + } + + /** @return the labelWeights */ + public Operand getLabelWeights() { + return labelWeights; + } + + /** @return the truePositives */ + public Variable getTruePositives() { + return truePositives; + } + + /** @return the falsePositives */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** @return the trueNegatives */ + public Variable getTrueNegatives() { + return trueNegatives; + } + + /** @return the falseNegatives */ + public Variable getFalseNegatives() { + return falseNegatives; + } + + /** @return the truePositivesName */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** @return the falsePositivesName */ + public String getFalsePositivesName() { + return falsePositivesName; + } + + /** @return the trueNegativesName */ + public String getTrueNegativesName() { + return trueNegativesName; + } + + /** @return the falseNegativesName */ + public String getFalseNegativesName() { + return falseNegativesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java new file mode 100644 index 00000000000..b5426a0dd8f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +/** + * Specifies the type of the curve to be computed, {@link #ROC} for a Receiver Operator + * Characteristic curve [default] or {@link #PR} for a Precision-Recall-curve. + */ +public enum AUCCurve { + /** Receiver Operator Characteristic curve */ + ROC, + /** Precision-Recall-curve */ + PR; + + /** + * Gets the AUCCurve enum value by name, regardless of case + * + * @param name the name of the AUCCurve enum value. + * @return the AUCCurve enum value. + */ + public AUCCurve get(String name) { + return AUCCurve.valueOf(name.toUpperCase()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java new file mode 100644 index 00000000000..09581c726d3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +/** + * Specifies the Riemann summation method used. {@link #INTERPOLATION} (default) applies mid-point + * summation scheme for ROC. For PR-AUC, interpolates (true/false) positives but not the ratio that + * is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left summation + * for increasing intervals and right summation for decreasing intervals; {@link #MAJORING} does the + * opposite. + * + * @see Davis & Goadrich. 2006 + * @see Riemann summation method + */ +public enum AUCSummationMethod { + INTERPOLATION, + MAJORING, + MINORING; + + /** + * Gets the AUCSummationMethod enum value by name, regardless of case + * + * @param name the name of the AUCSummationMethod enum value. + * @return the AUCSummationMethod enum value. + */ + public AUCSummationMethod get(String name) { + return AUCSummationMethod.valueOf(name.toUpperCase()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java new file mode 100644 index 00000000000..f69170e57b9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions equals labels. + * + *

This metric creates two local variables, total and count that are used to compute the + * frequency with which predictions matches labels. This frequency is + * ultimately returned as binary accuracy: an idempotent operation that simply divides total by + * count. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask + values. + + * @param The data type for the metric result + */ +public class Accuracy extends MeanMetricWrapper implements LossMetric { + + /** + * Creates an Accuracy Metric using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Accuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates an Accuracy Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Accuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + LossTuple tuple = + MetricsHelper.raggedAssertCompatibleAndGetFlatValues(getTF(), tLabels, tPredictions); + tLabels = tuple.getLabels(); + tPredictions = tuple.getTarget(); + + if (!predictions.shape().isCompatibleWith(labels.shape())) { + throw new IllegalArgumentException( + String.format( + "Shapes %s and %s are incompatible", + predictions.shape().toString(), labels.shape().toString())); + } + + // cast TBool to result type + return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java new file mode 100644 index 00000000000..9e7f0f874cc --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions matches binary labels. + * + *

This metric creates two local variables, total and count that are used to compute the + * frequency with which predictions matches labels. This frequency is + * ultimately returned as binary accuracy: an idempotent operation that simply divides total by + * count. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask + values. + + * @param The data type for the metric result + */ +public class BinaryAccuracy extends MeanMetricWrapper + implements LossMetric { + /** the default threshold value for deciding whether prediction values are 1 or 0 */ + public static final float DEFAULT_THRESHOLD = 0.5f; + + /** the threshold value for deciding whether prediction values are 1 or 0 */ + private final float threshold; + + /** + * Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name and + * {@link #DEFAULT_THRESHOLD} for the threshold value. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold for deciding whether prediction values are 1 or 0 + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, float threshold, long seed, Class type) { + this(tf, null, threshold, seed, type); + } + + /** + * Creates a BinaryAccuracy Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold for deciding whether prediction values are 1 or 0 + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class type) { + super(tf, name, seed, type); + this.threshold = threshold; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions) { + + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Operand thresholdCast = cast(getTF(), getTF().constant(threshold), getResultType()); + tPredictions = + cast(getTF(), getTF().math.greater(tPredictions, thresholdCast), getResultType()); + Operand tLabels = cast(getTF(), labels, getResultType()); + return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java new file mode 100644 index 00000000000..c0635746d4d --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions matches one-hot labels. + * + *

You can provide logits of classes as predictions, since argmax + * of logits and probabilities are same. + + *

This metric creates two local variables, total and count that are + * used to compute the frequency with which predictions matches labels. + * This frequency is ultimately returned as categorical accuracy: an idempotent operation that + * simply divides total by count. + * + *

predictions and labels should be passed in as vectors of + * probabilities, rather than as labels. If necessary, use {@link + * org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand + * labels as a vector. + * + *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. + * + * @param The data type for the metric result + */ +public class CategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a CategoricalAccuracy metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public CategoricalAccuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a CategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + super.setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand trueMax = getTF().math.argMax(labels, getTF().constant(-1)); + + Operand predMax = getTF().math.argMax(predictions, getTF().constant(-1)); + return cast(getTF(), getTF().math.equal(trueMax, predMax), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java new file mode 100644 index 00000000000..cf6f84af512 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of false negatives. + * + *

If sampleWeights is given, calculates the sum of the weights of false negatives. + * This metric creates one local variable, accumulator that is used to keep track of + * the number of false negatives. + * + *

If sampleWeights is null, weights default to 1. Use + * sampleWeights of 0 to mask values. + * + * @param The data type for the metric result + */ +public class FalseNegatives + extends ConfusionMatrixConditionCount { + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a FalseNegatives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalseNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalseNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range [0, 1]. 
A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.FALSE_NEGATIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java new file mode 100644 index 00000000000..629caaafb52 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of false positives. + * + *

If sampleWeights is given, calculates the sum of the weights of false positives. + * This metric creates one local variable, accumulator that is used to keep track of + * the number of false positives. + * + *

If sampleWeights is null, weights default to 1. Use + * sampleWeights of 0 to mask values. + * + * @param The data type for the metric result + */ +public class FalsePositives<T extends TNumber> + extends ConfusionMatrixConditionCount { + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a FalsePositives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalsePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalsePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range [0, 1]. 
A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.FALSE_POSITIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java new file mode 100644 index 00000000000..c8205565802 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -0,0 +1,163 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the mean Intersection-Over-Union metric. + * + *

Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, + * which first computes the IOU for each semantic class and then computes the average over classes. + * IOU is defined as follows: IOU = true_positive + * / (true_positive + false_positive + false_negative). The predictions are accumulated in a + * confusion matrix, weighted by sample_weight and the metric is then calculated from it. + * + *

If sampleWeight is null, weights default to 1. Use sample_weight of 0 to mask + * values. + * + * @param The data type for the metric result + */ +public class MeanIoU extends Metric { + + public static final String TOTAL_CONFUSION_MATRIX = "TOTAL_CONFUSION_MATRIX"; + private final String totalCMName; + private final Class type; + /** + * The possible number of labels the prediction task can have. This value must be provided, since + * a confusion matrix of dimension = [numClasses, numClasses] will be allocated. + */ + private final long numClasses; + + private Variable totalConfusionMatrix; + private Assign initializer; + + /** + * Creates a metric MeanIoU, using name as {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param numClasses The possible number of labels the prediction task can have + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + protected MeanIoU(Ops tf, long numClasses, long seed, Class type) { + this(tf, null, numClasses, seed, type); + } + + /** + * create a metric with reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param numClasses The possible number of labels the prediction task can have + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + protected MeanIoU(Ops tf, String name, long numClasses, long seed, Class type) { + super(tf, name, seed); + this.type = type; + this.totalCMName = this.getVariableName(TOTAL_CONFUSION_MATRIX); + this.numClasses = numClasses; + init(); + } + + private void init() { + Shape variableShape = Shape.of(numClasses, numClasses); + + if (totalConfusionMatrix == null) { + Zeros zeros = new Zeros<>(getTF()); + totalConfusionMatrix = + getTF().withName(totalCMName).variable(zeros.call(getTF().constant(variableShape), type)); + initializer = + getTF().assign(totalConfusionMatrix, zeros.call(getTF().constant(variableShape), type)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializer; + } + + /** + * Gets the initializer for the totalConfusionMatrix variable + * + * @return the initializer for the totalConfusionMatrix variable + */ + public Assign getInitializer() { + return initializer; + } + + /** {@inheritDoc} */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + + Operand tLabels = cast(getTF(), labels, type); + if (tLabels.shape().numDimensions() > 1) tLabels = getTF().shape.flatten(tLabels); + Operand tPredictions = cast(getTF(), predictions, type); + if (tPredictions.shape().numDimensions() > 1) + tPredictions = getTF().shape.flatten(tPredictions); + Operand tSampleWeights = sampleWeights != null ? 
cast(getTF(), sampleWeights, type) : null; + if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) + tSampleWeights = getTF().shape.flatten(tSampleWeights); + + Operand currentCM = + MetricsHelper.confusionMatrix( + getTF(), tLabels, tPredictions, getTF().constant(numClasses), tSampleWeights, type); + return Collections.singletonList(getTF().assignAdd(totalConfusionMatrix, currentCM)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand sumOverRow = tf.reduceSum(totalConfusionMatrix, tf.constant(0)); + Operand sumOverCol = tf.reduceSum(totalConfusionMatrix, tf.constant(1)); + Operand truePositives = + tf.linalg.matrixDiagPart( + totalConfusionMatrix, + tf.constant(0), + cast(tf, tf.constant(0), totalConfusionMatrix.type())); + Operand denominator = tf.math.add(sumOverRow, tf.math.sub(sumOverCol, truePositives)); + Operand numValidEntries = + tf.reduceSum( + tf.dtypes.cast( + tf.math.notEqual(denominator, cast(tf, tf.constant(0), denominator.type())), type), + allAxes(tf, denominator)); + Operand iou = tf.math.divNoNan(truePositives, denominator); + + Operand iouSum = tf.reduceSum(iou, allAxes(tf, iou)); + return tf.math.divNoNan(iouSum, numValidEntries); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java new file mode 100644 index 00000000000..eb8ccaf76d2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -0,0 +1,173 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the mean relative error by normalizing with the given values. + * + *

This metric creates two local variables, total and count that are + * used to compute the mean relative error. This is weighted by sampleWeight, and it is + * ultimately returned as mean relative error: an idempotent operation that simply divides total by + * count. + * + *

If sampleWeight is null, weights default to 1. Use sample_weight of + * 0 to mask * values. + * + * @param The data type for the metric result + */ +public class MeanRelativeError extends Mean { + private Operand normalizer; + + /** + * create a metric with name = class name and reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + */ + protected MeanRelativeError(Ops tf, float[] normalizer, long seed, Class type) { + this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * create a metric with reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, String name, float[] normalizer, long seed, Class type) { + this(tf, name, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * Creates a MeanRelativeError metric with name = class name and reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, double[] normalizer, long seed, Class type) { + this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * create a metric with reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param name the name of the metric. 
If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, String name, double[] normalizer, long seed, Class type) { + this(tf, name, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * create a metric with name = class name and reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, Operand normalizer, long seed, Class type) { + this(tf, null, normalizer, seed, type); + } + + /** + * create a metric + * + * @param tf the TensorFlow ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result + */ + protected MeanRelativeError( + Ops tf, String name, Operand normalizer, long seed, Class type) { + super(tf, name, seed, type); + this.normalizer = normalizer; + } + + /** {@inheritDoc} */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Operand tLabels = cast(getTF(), labels, getResultType()); + if (tLabels.shape().numDimensions() > 1) tLabels = getTF().shape.flatten(tLabels); + + Operand tPredictions = cast(getTF(), predictions, getResultType()); + if (tPredictions.shape().numDimensions() > 1) + tPredictions = getTF().shape.flatten(tPredictions); + + LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); + tPredictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + Operand tSampleWeights = + sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { + tSampleWeights = getTF().shape.flatten(tSampleWeights); + } + + tuple = LossesHelper.removeSqueezableDimensions(getTF(), normalizer, tPredictions); + normalizer = tuple.getLabels(); + tPredictions = tuple.getTarget(); + + if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with labels shape %s", + tPredictions.shape(), tLabels.shape())); + + Operand relativeErrors = + getTF() + .math + .divNoNan( + getTF().math.abs(getTF().math.sub(tLabels, tPredictions)), this.getNormalizer()); + + return super.updateStateList(relativeErrors, tSampleWeights); + } + + /** + * Gets the normalizer Operand + * + * @return the normalizer + */ + public Operand getNormalizer() { + return normalizer; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java new file mode 100644 index 00000000000..d9c767965a6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -0,0 +1,186 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.framework.metrics.impl.WeightsBroadcastOps; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that computes the element-wise (weighted) mean of the given tensors. 
+ * + * @param The data type for the metric result + */ +public class MeanTensor extends Metric { + public static final String TOTAL = "total"; + public static final String COUNT = "count"; + private final String totalName; + private final String countName; + private final Class type; + private Shape shape; + private Variable total; + private Variable count; + private Assign totalInitializer; + private Assign countInitializer; + private boolean initialized; + + /** + * Creates a MeanTensor metric, using {@link Class#getSimpleName()} as the name + * + * @param tf the TensorFlow ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public MeanTensor(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + /** + * Creates a MeanTensor metric + * + * @param tf the TensorFlow ops + * @param name the name of this metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public MeanTensor(Ops tf, String name, long seed, Class type) { + super(tf, name, seed); + this.type = type; + this.totalName = this.getVariableName(TOTAL); + this.countName = this.getVariableName(COUNT); + } + + /** + * Creates the Operations that initialize the total and count variables. 
+ * + * @param shape the shape of the variables + * @return true if the variables need initialization, otherwise false; + */ + private boolean init(Shape shape) { + if (!initialized) { + this.shape = shape; + Zeros zeros = new Zeros<>(getTF()); + Operand zero = zeros.call(getTF().constant(shape), type); + + if (total == null) { + total = getTF().withName(totalName).variable(zero); + totalInitializer = getTF().assign(total, zero); + } + if (count == null) { + count = getTF().withName(countName).variable(zero); + countInitializer = getTF().assign(count, zero); + } + this.initialized = true; + return true; + } else { + return false; + } + } + + /** {@inheritDoc */ + @Override + public List updateStateList( + Operand values, Operand sampleWeights) { + Ops tf = getTF(); + Operand tValues = cast(tf, values, type); + Operand tSampleWeights = null; + if (sampleWeights != null) tSampleWeights = cast(tf, sampleWeights, type); + + boolean needsInitialization = init(values.shape()); + + if (!this.shape.equals(values.shape())) { + throw new IllegalArgumentException( + String.format( + "MeanTensor input values must always have the same shape. Expected shape (set during the first call): %s. 
Got %s", + this.shape.toString(), values.shape().toString())); + } + + Operand numValues = tf.onesLike(tValues); + if (tSampleWeights != null) { + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); + try { + tSampleWeights = WeightsBroadcastOps.broadcastWeights(tf, tSampleWeights, tValues); + } catch (IllegalArgumentException ex) { + int ndim = values.shape().numDimensions(); + int weightNdim = tSampleWeights.asOutput().shape().numDimensions(); + int[] range = new int[ndim - weightNdim]; + for (int i = weightNdim; i < ndim; i++) { + range[i] = i; + } + tValues = tf.math.mean(tValues, tf.constant(range)); + } + numValues = tf.math.mul(numValues, tSampleWeights); + tValues = tf.math.mul(tValues, tSampleWeights); + } + + List controlOpsPre = new ArrayList<>(); + if (needsInitialization) { + controlOpsPre.add(countInitializer); + controlOpsPre.add(totalInitializer); + } + Ops tf1 = tf.withSubScope("variables").withControlDependencies(controlOpsPre); + + List controlOps = new ArrayList<>(); + controlOps.add(tf1.assignAdd(this.count, numValues)); + controlOps.add(tf1.assignAdd(this.total, tValues)); + return controlOps; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + if (!this.initialized) { + throw new IllegalStateException( + "MeanTensor does not have any result yet. 
Please use `.update_state(value)` before retrieving the result."); + } + return getTF().math.divNoNan(total, count); + } + + /** @return the total */ + public Variable getTotal() { + return total; + } + + /** @return the count */ + public Variable getCount() { + return count; + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + List controlOpsPre = new ArrayList<>(); + controlOpsPre.add(countInitializer); + controlOpsPre.add(totalInitializer); + return getTF().withSubScope("resetStates").withControlDependencies(controlOpsPre).noOp(); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 95b74bf1eea..e4cc9c3aa3d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -16,15 +16,15 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; /** Helper class with built-in metrics functions. */ public class Metrics { - public static final float L2_NORM_EPSILON = 1e-12f; - /** * Computes how often targets are in the top K predictions. * @@ -55,4 +55,52 @@ public static Operand topKCategoricalAccuracy( tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); } + + /** + * Computes how often integer targets are in the top K predictions. + * + *

Standalone usage: + * + *

+   *     Operand<TInt32> labels = tf.constant(new int[]{2, 1});
+   *     Operand<TFloat32> predictions = tf.constant(new float[][]
+   *                            {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+   *     Operand<TFloat32> m = Metrics.topKCategoricalAccuracy(
+   *                                    labels, predictions, 3);
+   *     //m.shape().toString() == "[2]"
+   * 
+ * + * @param tf the TensorFlow Ops. + * @param labels the ground truth values. + * @param predictions The prediction values. + * @param k Number of top elements to look at for computing accuracy. + * @param the data type for the predictions and results + * @param the data type ofr the labels. + * @return the Operand for the Sparse top K categorical accuracy value. + */ + @SuppressWarnings("unchecked") + public static Operand sparseTopKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, int k) { + Operand tLabels; + if (labels.type() != predictions.type()) + tLabels = CastHelper.cast(tf, labels, predictions.type()); + else tLabels = (Operand) labels; + + int predictionsRank = predictions.shape().numDimensions(); + int labelsRank = tLabels.shape().numDimensions(); + + Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); + if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { + if (predictionsRank > 2) { + castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); + } + if (labelsRank > 1) { + tLabels = tf.shape.flatten(tLabels); + } + } + return CastHelper.cast( + tf, + tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), + predictions.type()); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java new file mode 100644 index 00000000000..6b70c6680cb --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -0,0 +1,400 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the precision of the predictions with respect to the labels. + * + *

The metric creates two local variables, truePositives and falsePositives that are used to + * compute the precision. This value is ultimately returned as precision, an idempotent operation + * that simply divides truePositives by the sum of truePositives and falsePositives. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values. + * + *

If is set, the metric calculates precision as how often on average a class among the top-k + * classes with the highest predicted values of a batch entry is correct and can be found in the + * label for that entry. + * + *

If classId is specified, the metric calculates precision by considering only the entries in the batch + * for which classId is above the thresholds and/or in the top-k highest predictions, and computing + * the fraction of them for which classId is indeed a correct label. + * + * @param The data type for the metric result + */ +public class Precision extends Metric { + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final float DEFAULT_THRESHOLD = 0.5f; + + private final float[] thresholds; + private final Integer topK; + private final Integer classId; + private final String truePositivesName; + private final String falsePositivesName; + private final Class type; + private Variable truePositives; + private Variable falsePositives; + private final List initializers = new ArrayList<>(); + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values and with a threshold of {@link #DEFAULT_THRESHOLD).} + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, long seed, Class type) { + this(tf, null, new float[] {DEFAULT_THRESHOLD}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values with a threshold of {@link + * #DEFAULT_THRESHOLD).} + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Precision(Ops tf, String name, long seed, Class type) { + this(tf, name, new float[] {DEFAULT_THRESHOLD}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values. + * + * @param tf the TensorFlow Ops + * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values. + * + * @param tf the TensorFlow Ops + * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. 
If null, name defaults to {@link + * Class#getSimpleName()}. + * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, String name, float[] thresholds, long seed, Class type) { + this(tf, name, thresholds, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. 
+ * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, float threshold, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, thresholds, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. 
+ * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, + String name, + float threshold, + Integer topK, + Integer classId, + long seed, + Class type) { + this(tf, name, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Precision( + Ops tf, + String name, + float[] thresholds, + Integer topK, + Integer classId, + long seed, + Class type) { + super(tf, name, seed); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); + float defaultThreshold = topK == null ? DEFAULT_THRESHOLD : MetricsHelper.NEG_INF; + this.thresholds = thresholds == null ? new float[] {defaultThreshold} : thresholds; + this.topK = topK; + this.classId = classId; + + init(); + } + + /** Initializes the variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Operand zero = zeros.call(tf.constant(Shape.of(thresholds.length)), type); + + if (this.truePositives == null) { + this.truePositives = + tf.withName(truePositivesName) + .variable(zero); + initializers.add(tf.assign(truePositives, zero)); + + } + if (this.falsePositives == null) { + this.falsePositives = + tf.withName(falsePositivesName) + .variable(zeros.call(tf.constant(Shape.of(thresholds.length)), type)); + initializers.add(tf.assign(falsePositives, zero)); + } + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives); + + Operand tPredictions = cast(getTF(), predictions, type); + Operand tLabels = cast(getTF(), labels, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(getTF(), sampleWeights, type) : null; + + return new ArrayList( + MetricsHelper.updateConfusionMatrixVariables( + getTF(), + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + thresholds, + topK, + classId, + tSampleWeights, + false, + null)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand result = + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + return thresholds.length == 1 + ? tf.slice( + result, + tf.expandDims(tf.constant(0), tf.constant(0)), + tf.expandDims(tf.constant(1), tf.constant(0))) + : result; + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); + } + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return thresholds; + } + + /** + * Gets the topK value, may be null + * + * @return the topK + */ + public Integer getTopK() { + return topK; + } + + /** + * Gets the classId, may be null + * + * @return the classId + */ + public Integer getClassId() { + return classId; + } + + /** + * Gets the truePositives variable + * + * @return the truePositives + */ + public Variable getTruePositives() { + return truePositives; + } + + /** Gets the falsePositives variable return the falsePositives */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** + * Gets the name of the truePositives variable + * + * @return the truePositivesName + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the name of the falsePositives variable + * + * @return the falsePositivesName + */ + public String getFalsePositivesName() { + return falsePositivesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java new file mode 100644 index 00000000000..2ec66df0ca9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best precision where recall is >= specified value. + * @param The data type for the metric result + */ +public class PrecisionAtRecall + extends SensitivitySpecificityBase { + + private final float recall; + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()} and {@link + * #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param recall the recall. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { + this(tf, null, recall, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number of + * thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param recall the recall. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class type) { + this(tf, name, recall, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param recall the recall. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Class type) { + this(tf, null, recall, numThresholds, seed, type); + } + + /** + * Creates a PrecisionRecall metric. 
+ * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param recall the recall. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public PrecisionAtRecall( + Ops tf, String name, float recall, int numThresholds, long seed, Class type) { + super(tf, name, recall, numThresholds, seed, type); + if (recall < 0f || recall > 1f) + throw new IllegalArgumentException("recall must be in the range [0, 1]."); + this.recall = recall; + } + + @Override + public Operand result() { + Ops tf = getTF(); + + Operand recall = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + Operand sub = tf.math.sub(recall, cast(tf, tf.constant(value), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(this.truePositives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(this.falsePositives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** @return the recall */ + public float getRecall() { + return recall; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java new file mode 100644 index 00000000000..0672b78f229 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ 
-0,0 +1,426 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the recall of the predictions with respect to the labels. + *

This metric creates two local + * variables, truePositives and falseNegatives, that are used to compute the recall. This value is + * ultimately returned as recall, an idempotent operation that simply divides truePositives by the sum of truePositives and falseNegatives. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values. + * + *

If is set, the metric calculates recall as how often on average a class among the labels of a + * batch entry is in the top-k predictions. + * + *

If classId is specified, the metric calculates recall by considering only the entries in the batch + * for which classId is in the label, and computing the fraction of them for which classId is above + * the threshold and/or in the top-k predictions. + * + * @param The data type for the metric result + */ +public class Recall< T extends TNumber> extends Metric< T> { + public static final float DEFAULT_THRESHOLD = 0.5f; + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + + private final float[] thresholds; + private final Integer topK; + private final Integer classId; + private final String truePositivesName; + private final String falseNegativesName; + private final Class type; + private Variable truePositives; + private Variable falseNegatives; + private final List initializers = new ArrayList<>(); + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null, and thresholds set to {@link #DEFAULT_THRESHOLD} + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, long seed, Class type) { + this(tf, null, null, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null and thresholds set to {@link + * #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Recall(Ops tf, String name, long seed, Class type) { + this(tf, name, null, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null. + * + * @param tf The TensorFlow Ops + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float threshold, long seed, Class type) { + this(tf, null, threshold, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null. + * + * @param tf The TensorFlow Ops + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. 
+ * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, threshold, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, float[] thresholds, long seed, Class type) { + this(tf, name, thresholds, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} and using a threshold + * value of {@link #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, null, topK, classId, seed, type); + } + + /** + * Creates a Recall metric using a threshold value of {@link #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, Integer topK, Integer classId, long seed, Class type) { + this(tf, name, null, topK, classId, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float threshold, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, thresholds, topK, classId, seed, type); + } + + /** + * Creates a Recall metric. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. 
This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, + String name, + float threshold, + Integer topK, + Integer classId, + long seed, + Class type) { + this(tf, name, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Recall metric. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, + String name, + float[] thresholds, + Integer topK, + Integer classId, + long seed, + Class type) { + super(tf, name, seed); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); + float defaultThreshold = topK == null ? DEFAULT_THRESHOLD : MetricsHelper.NEG_INF; + + this.thresholds = thresholds == null ? 
new float[] {defaultThreshold} : thresholds; + this.topK = topK; + this.classId = classId; + + init(); + } + + /** Initializes the Variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Operand zero = zeros.call(tf.constant(Shape.of(this.thresholds.length)), type); + if (truePositives == null) { + + truePositives = + tf.withName(truePositivesName) + .variable(zero); + initializers.add(tf.assign(truePositives, zero)); + } + + if (this.falseNegatives == null) { + + falseNegatives = + tf.withName(falseNegativesName) + .variable(zero); + initializers.add(tf.assign(falseNegatives, zero)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives); + + Operand tPredictions = cast(getTF(), predictions, type); + Operand tLabels = cast(getTF(), labels, type); + Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null; + + return MetricsHelper.updateConfusionMatrixVariables( + getTF(), + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + this.thresholds, + this.topK, + this.classId, + tSampleWeights, + false, + null); + } + + @Override + public Operand result() { + Ops tf = getTF(); + Operand result = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + return this.thresholds.length == 1 + ? 
tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1])) + : result; + } + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return this.thresholds; + } + + /** + * Gets the topK value + * + * @return the topK value + */ + public Integer getTopK() { + return this.topK; + } + + /** + * Gets the class id + * + * @return the class id + */ + public Integer getClassId() { + return this.classId; + } + + /** + * Gets the truePositives variable + * + * @return the truePositives variable + */ + public Variable getTruePositives() { + return this.truePositives; + } + + /** + * Gets the falseNegatives variable + * + * @return the falseNegatives variable + */ + public Variable getFalseNegatives() { + return this.falseNegatives; + } + + /** + * Gets the truePositives variable name + * + * @return the truePositives variable name + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the falseNegatives variable name + * + * @return the falseNegatives variable name + */ + public String getFalseNegativesName() { + return falseNegativesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java new file mode 100644 index 00000000000..6c774f0c765 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Where; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class RecallAtPrecision + extends SensitivitySpecificityBase { + + private final float precision; + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()} and {@link + * #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param precision the precision. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { + this(tf, null, precision, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number of + * thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric. 
If null, defaults to {@link Class#getSimpleName()} + * @param precision the precision. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class type) { + this(tf, name, precision, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param precision the precision. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, Class type) { + this(tf, null, precision, numThresholds, seed, type); + } + + /** + * Creates a PrecisionRecall metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param precision the precision. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public RecallAtPrecision( + Ops tf, String name, float precision, int numThresholds, long seed, Class type) { + super(tf, name, precision, numThresholds, seed, type); + if (precision < 0f || precision > 1f) + throw new IllegalArgumentException("recall must be in the range [0, 1]."); + this.precision = precision; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + + Operand precisions = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falsePositives)); + Operand recalls = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + Operand isFeasible = + tf.math.greaterEqual(precisions, cast(tf, tf.constant(this.value), getType())); + Where feasible = tf.where(isFeasible); + Operand feasibleExists = tf.math.greater(tf.size(feasible), tf.constant(0)); + + Operand gather = + tf.expandDims(tf.gather(recalls, feasible, tf.constant(0)), tf.constant(0)); + return tf.select( + feasibleExists, + tf.reduceMax(gather, allAxes(tf, gather)), + cast(tf, tf.constant(0), getType())); + } + + /** + * Gets the precision + * + * @return the precision + */ + public float getPrecision() { + return precision; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java new file mode 100644 index 00000000000..2133642564b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes root mean squared error metric between labels> and predictions + * . + * + * @param The data type for the metric result + */ +public class RootMeanSquaredError< T extends TNumber> extends Mean< T> { + + /** + * Creates a RootMeanSquaredError metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public RootMeanSquaredError(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a RootMeanSquaredError metric + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + } + + /** {@inheritDoc} */ + @Override + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); + tPredictions = ops.getTarget(); + tLabels = ops.getLabels(); + + Operand errorSquared = + cast(getTF(), getTF().math.squaredDifference(tPredictions, tLabels), getResultType()); + + return super.updateStateList(errorSquared, tSampleWeights); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + return getTF().math.sqrt(getTF().math.divNoNan(this.total, this.count)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java new file mode 100644 index 00000000000..7cf694868e6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -0,0 +1,150 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best sensitivity where sensitivity is >= specified value. + * + *

Sensitivity measures the proportion of actual positives that are correctly + * identified as such (tp / (tp + fn)). + * + *

Specificity measures the proportion of actual negatives that are correctly + * identified as such (tn / (tn + fp)). + * + *

This metric creates four local variables, truePositives, trueNegatives + * , falsePositives and falseNegatives that are used to compute the + * sensitivity at the given specificity. The threshold for the given specificity value is computed + * and used to evaluate the corresponding sensitivity. + * + *

If sampleWeights is null>, weights default to 1. Use sample_weight + * of 0 to mask values. + * + * @see Additional information + * about specificity and sensitivity + * @param The data type for the metric result + */ +public class SensitivityAtSpecificity + extends SensitivitySpecificityBase { + + private final float specificity; + + /** + * Creates a SpecificityAtSensitivity metric with a name of {@link Class#getSimpleName()} and + * {@link #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param specificity the specificity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. + */ + public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class type) { + this(tf, null, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a SpecificityAtSensitivity metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number + * of thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param specificity the specificity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. 
+ */ + public SensitivityAtSpecificity( + Ops tf, String name, float specificity, long seed, Class type) { + this(tf, name, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param specificity the specificity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * specificity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. + */ + public SensitivityAtSpecificity( + Ops tf, float specificity, int numThresholds, long seed, Class type) { + this(tf, null, specificity, numThresholds, seed, type); + } + + /** + * Creates a PrecisionRecall metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param specificity the specificity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * specificity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. 
+ */ + public SensitivityAtSpecificity( + Ops tf, String name, float specificity, int numThresholds, long seed, Class type) { + super(tf, name, specificity, numThresholds, seed, type); + if (specificity < 0f || specificity > 1f) + throw new IllegalArgumentException("specificity must be in the range [0, 1]."); + this.specificity = specificity; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand specificities = + tf.math.divNoNan( + this.trueNegatives, tf.math.add(this.trueNegatives, this.falsePositives)); + Operand sub = + tf.math.sub(specificities, cast(tf, tf.constant(this.getValue()), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(this.truePositives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(this.falseNegatives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** + * Gets the specificity + * + * @return the specificity + */ + public float getSpecificity() { + return specificity; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java new file mode 100644 index 00000000000..156a4995b02 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -0,0 +1,135 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.op.math.Equal; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.Collections; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Calculates how often predictions matches integer labels. + * + *

You can provide logits of classes as predictions, since argmax of logits and probabilities are + * same. + * + *

This metric creates two local variables, `total` and `count` that are used to compute the + * frequency with which predictions matches labels. This frequency is ultimately returned as `sparse + * categorical accuracy`: an idempotent operation that simply divides `total` by `count`. + * + *

If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.' + * + *

Usage: + * + *

+ * + *

+ * SparseCategoricalAccuracy<TFloat32> m = new SparseCategoricalAccuracy<>(tf, seed, TFloat32.class);
+ * m.updateState(tf.constant(new float[][] {{2}, {1}}),
+ *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}));
+ * Operand<TFloat32> result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.5
+ * 
+ * + *
+ * m.resetStates();
+ * m.updateState(
+ *     tf.constant(new float[][] {{2}, {1}}),
+ *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}),
+ *     tf.constant(new float[] {0.7f, 0.3f}));
+ * result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.3
+ * 
+ * + *

Usage with tf.keras API: + * + *

+ * Model model = new Model(inputs, outputs);
+ * model.compile(
+ *     "sgd",
+ *     loss="mse",
+ *     metrics=["sparse_categorical_accuracy"]);
+ * 
+ * + * @param The data type for the metric result + */ +public class SparseCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a SparseCategoricalAccuracy metric, using name of {@link Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type The data type for the metric result + */ + public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a SparseCategoricalAccuracy metric. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null use {@link Class#getSimpleName()} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type of the metric result. 
+ */ + public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + super.setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, + Operand predictions) { + + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Shape predShape = predictions.asOutput().shape(); + Shape labelsShape = labels.asOutput().shape(); + long predictionsRank = predShape.numDimensions(); + long labelsRank = labelsShape.numDimensions(); + + if (predictionsRank != Shape.UNKNOWN_SIZE + && labelsRank != Shape.UNKNOWN_SIZE + && labelsShape.size((int) labelsRank - 1) == 1) { + tLabels = getTF().squeeze(tLabels, Squeeze.axis(Collections.singletonList(labelsRank - 1L))); + } + Operand argMaxPred = + cast( + getTF(), + getTF().math.argMax(tPredictions, getTF().constant(-1L), TInt64.class), + getResultType()); + + Equal equals = getTF().math.equal(tLabels, argMaxPred); + return getTF().dtypes.cast(equals, getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java new file mode 100644 index 00000000000..7db290530cd --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** @param The data type for the metric result */ +public class SparseTopKCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Metrics.sparseTopKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java new file mode 100644 index 00000000000..59f6f44c1f2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -0,0 +1,151 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best specificity where sensitivity is >= specified value. Sensitivity + * measures the proportion of actual positives that are correctly identified as such + * (tp / (tp + fn)). + * + *

Specificity measures the proportion of actual negatives that are correctly + * identified as such (tn / (tn + fp)). + * + *

This metric creates four local variables, truePositives, trueNegatives + * , falsePositives and falseNegatives that are used to compute the + * specificity at the given sensitivity. The threshold for the given sensitivity value is computed + * and used to evaluate the corresponding specificity. + * + *

If sampleWeights is null>, weights default to 1. Use sample_weight + * of 0 to mask values. + * + * @see Additional information + * about specificity and sensitivity + * @param The data type for the metric result + */ +public class SpecificityAtSensitivity + extends SensitivitySpecificityBase { + + private final float sensitivity; + + /** + * Creates a SpecificityAtSensitivity metric with a name of {@link Class#getSimpleName()} and + * {@link #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. + */ + public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class type) { + this(tf, null, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a SpecificityAtSensitivity metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number + * of thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. 
+ */ + public SpecificityAtSensitivity( + Ops tf, String name, float sensitivity, long seed, Class type) { + this(tf, name, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * sensitivity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. + */ + public SpecificityAtSensitivity( + Ops tf, float sensitivity, int numThresholds, long seed, Class type) { + this(tf, null, sensitivity, numThresholds, seed, type); + } + + /** + * Creates a PrecisionRecall metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * sensitivity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. 
+ */ + public SpecificityAtSensitivity( + Ops tf, String name, float sensitivity, int numThresholds, long seed, Class type) { + super(tf, name, sensitivity, numThresholds, seed, type); + if (sensitivity < 0f || sensitivity > 1f) + throw new IllegalArgumentException("sensitivity must be in the range [0, 1]."); + this.sensitivity = sensitivity; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + + Ops tf = getTF(); + + Operand sensitivities = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + Operand sub = + tf.math.sub(sensitivities, cast(tf, tf.constant(this.getValue()), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(this.trueNegatives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(this.falsePositives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** + * Gets the sensitivity + * + * @return the sensitivity + */ + public float getSensitivity() { + return sensitivity; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java new file mode 100644 index 00000000000..4312d7a97f0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.Reduce; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the (weighted) sum of the given values. + * + *

For example, if values is [1, 3, 5, 7] then the sum is 16. If the + * weights were specified as [1, 1, 0, 0], then the sum would be 4. + * + *

This metric creates one variable, total, that is used to compute the sum of + * values. This is ultimately returned as sum. + * + *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. + * + + */ +public class Sum extends Reduce { + + /** + * Creates a Sum metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Sum(Ops tf, long seed, Class type) { + super(tf, null, MetricReduction.SUM, seed, type); + } + + /** + * Creates a Sum metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric instance. If null, defaults to {@link Class#getSimpleName()} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Sum(Ops tf, String name, long seed, Class type) { + super(tf, name, MetricReduction.SUM, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java new file mode 100644 index 00000000000..d2db4f368ac --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Computes the poisson loss metric between labels and predictions. + * + * @param The data type for the metric result + */ +public class TopKCategoricalAccuracy + extends MeanMetricWrapper implements LossMetric { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
+ * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Metrics.topKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java new file mode 100644 index 00000000000..de6428fed88 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of true negatives. + * + *

If sampleWeights is given, calculates the sum of the weights of true negatives. + * This metric creates one local variable, accumulator that is used to keep track of + * the number of true negatives. + * + *

If sampleWeightsnull, weights + * default to 1. Use + * sampleWeights of 0 to mask values. + * + * @param The data type for the metric result + */ +public class TrueNegatives + extends ConfusionMatrixConditionCount { + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a TrueNegatives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TrueNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a TrueNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range [0, 1]. 
A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.TRUE_NEGATIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java new file mode 100644 index 00000000000..c573b6b5719 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of true positives. + * + *

If sampleWeights is given, calculates the sum of the weights of true positives. + * This metric creates one local variable, accumulator that is used to keep track of + * the number of true positives. + * + *

If sampleWeightsnull, weights + * default to 1. Use + * sampleWeights of 0 to mask values. + * @param The data type for the metric result + */ +public class TruePositives + extends ConfusionMatrixConditionCount< T> { + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a TruePositives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TruePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a TruePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range [0, 1]. 
A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.TRUE_POSITIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java new file mode 100644 index 00000000000..c9e762d05d4 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -0,0 +1,186 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Abstract base class that calculates the value of the given confusion matrix condition based on + * labels and predictions. + * + * @param The data type for the metric result + */ +public abstract class ConfusionMatrixConditionCount extends Metric { + public static final String ACCUMULATOR = "accumulator"; + public static final float DEFAULT_THRESHOLD = 0.5f; + private final ConfusionMatrixEnum confusionMatrixCond; + private final float[] thresholds; + private final String accumulatorName; + private final Class type; + private Variable accumulator; + private Assign initializer; + + /** + * Creates a ConfusionMatrixConditionCount type of Metric, using a threshold of {@link + * #DEFAULT_THRESHOLD} + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, String name, ConfusionMatrixEnum confusionMatrixCond, long seed, Class type) { + this(tf, name, confusionMatrixCond, DEFAULT_THRESHOLD, seed, type); + } + /** + * Creates a ConfusionMatrixConditionCount type of Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param threshold a threshold value in [0, 1]. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * true, below is false). One metric value is generated for each + * threshold value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, + String name, + ConfusionMatrixEnum confusionMatrixCond, + float threshold, + long seed, + Class type) { + this(tf, name, confusionMatrixCond, new float[] {threshold}, seed, type); + } + + /** + * Creates a ConfusionMatrixConditionCount type of Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param thresholds threshold values in [0, 1]. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * true, below is false). One metric value is generated for each + * threshold value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, + String name, + ConfusionMatrixEnum confusionMatrixCond, + float[] thresholds, + long seed, + Class type) { + super(tf, name, seed); + accumulatorName = this.getVariableName(ACCUMULATOR); + this.type = type; + this.confusionMatrixCond = confusionMatrixCond; + this.thresholds = thresholds; + init(); + } + + private void init() { + Shape variableShape = Shape.of(this.thresholds.length); + + Zeros zeros = new Zeros<>(getTF()); + accumulator = + getTF() + .withName(getAccumulatorName()) + .variable(zeros.call(getTF().constant(variableShape), type)); + initializer = getTF().assign(accumulator, zeros.call(getTF().constant(variableShape), type)); + } + + /** + * Gets the initializer for the accumulator variable + * + * @return the initializer for the accumulator variable + */ + public Assign getInitializer() { + return initializer; + } + + /** {@inheritDoc} */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Operand tLabels = cast(getTF(), labels, type); + Operand tPredictions = cast(getTF(), predictions, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(getTF(), sampleWeights, type) : null; + return new ArrayList<>( + MetricsHelper.updateConfusionMatrixVariables( + getTF(), + Collections.singletonMap(confusionMatrixCond, accumulator), + Collections.singletonMap(confusionMatrixCond, initializer), + tLabels, + tPredictions, + thresholds, + null, + null, + tSampleWeights, + false, + null)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + return getTF().identity(accumulator); + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializer; + } + + /** + * get the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return this.thresholds; + } + + /** @return the accumulatorName */ + public String getAccumulatorName() { + return accumulatorName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java new file mode 100644 index 00000000000..b76356661a9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +/** Enumerate the values for a confusion matrix. 
*/ +public enum ConfusionMatrixEnum { + /** These are cases in which the prediction is true, and reality is true. */ + TRUE_POSITIVES("tp"), + /** These are cases in which the prediction is true, and reality is false. */ + FALSE_POSITIVES("fp"), + /** These are cases in which the prediction is false, and reality is false. */ + TRUE_NEGATIVES("tn"), + /** These are cases in which the prediction is false, and reality is true. */ + FALSE_NEGATIVES("fn"); + + private final String abbrev; + + /** Creates a ConfusionMatrixEnum */ + ConfusionMatrixEnum(String abbrev) { + this.abbrev = abbrev; + } + + /** + * Gets the ConfusionMatrixEnum for this enum value, regardless of case. + * + * @param item either the name of the enumeration value or the abbreviation. + * @return ConfusionMatrixEnum for this enum value, or null if not found. + */ + public static ConfusionMatrixEnum get(String item) { + ConfusionMatrixEnum cm = valueOf(item.toUpperCase()); + if (cm == null) { + for (ConfusionMatrixEnum m : values()) { + if (m.getAbbreviation().equals(item.toLowerCase())) { + return m; + } + } + } + return null; + } + + /** Gets the abbreviation for this enum value */ + public String getAbbreviation() { + return abbrev; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 8a352322f52..cbb24933967 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,20 +15,26 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; +import
org.tensorflow.framework.utils.SparseTensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Stack; +import org.tensorflow.op.core.*; import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.TopK; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; +import java.util.*; +import java.util.concurrent.atomic.AtomicLong; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -212,7 +218,383 @@ public static Operand broadcastWeights( return ctf.math.mul(weights, tf.onesLike(values)); } - // aliases for mean + /** + * Checks that all the Symbolic Shapes are consistent. + * + * @param tf the TensorFlow Ops + * @param symbols the list of Symbolic Shapes + * @param message the error message if the shapes are not consistent. + * @return a list of Operands to check the consistency of the symbolic shapes ready to add to a + * control dependency. + */ + public static List assertShapes( + Ops tf, List> symbols, String message) { + List updateOperations = new ArrayList<>(); + // check that the symbolic shape rank matches the operands rank. 
+ symbols.forEach( + symbol -> { + Operand operand = symbol.getOperand(); + int rank = symbol.rank(); + Rank tfRank = tf.rank(operand); + Op assertion = + tf.withSubScope("assertShapes-1") + .assertThat( + tf.math.equal(tfRank, tf.constant(rank)), + Collections.singletonList(tf.constant(message))); + updateOperations.add(assertion); + }); + + Map dict = new HashMap<>(); + + // check that each operand's dimension size equals the corresponding symbolic shape's dimensions + // size + symbols.forEach( + symbol -> { + AtomicLong ll = new AtomicLong(); + symbol + .getSymbols() + .forEach( + s -> { + Long size = dict.get(s); + if (size == null) { + size = symbol.getOperand().asOutput().shape().size((int) ll.get()); + dict.put(s, size); + } + Op assertion = + tf.withSubScope("assertShapes-2") + .assertThat( + tf.math.equal( + tf.shape.size( + symbol.getOperand(), + tf.constant(ll.getAndIncrement()), + TInt64.class), + tf.constant(size)), + Collections.singletonList(tf.constant(message))); + updateOperations.add(assertion); + }); + }); + + return updateOperations; + } + + /** + * Returns an op to update the given confusion matrix variables. + * + *

For every pair of values in labels and predictions: + * + *

+   * TRUE_POSITIVES:  labels == true and predictions > thresholds
+   * FALSE_POSITIVES: labels == false and predictions > thresholds
+   * TRUE_NEGATIVES:  labels == false and predictions <= thresholds
+   * FALSE_NEGATIVES: labels == true and predictions <= thresholds
+   * 
+ * + *

The results will be weighted and added together. When multiple thresholds are provided, we + * will repeat the same for every threshold. + * + *

For estimation of these metrics over a stream of data, the function creates an `update_op` + * operation that updates the given variables. + * + *

If sampleWeight is null, weights default to 1. Use weights of 0 to + * mask values. + * + * @param tf the TensorFlow Ops + * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding variables to update as values. + * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding initializer Operands to initializer the corresponding variables from + * variablesToUpdate. + * @param labels the labels, will be cast to {@link TBool} + * @param predictions the predictions whose values are in the range [0, 1]. + * @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when + * topK is set) + * @param topK Optional, indicates that the positive labels should be limited to the top k + * predictions, may be null. + * @param classId Optional, limits the prediction and labels to the specified class + * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as + * labels, and must be broadcast to labels (i.e., all dimensions + * must be either 1, or the same as the corresponding labels + * dimension). + * @param multiLabel indicates whether multidimensional prediction/labels should be treated as + * multilabel responses, or flattened into a single label. When true, the values of + * variablesToUpdate must have a second dimension equal to the number of labels and + * predictions, and those tensors must not be RaggedTensors. + * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied + * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES + * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. 
+ * @param the data type for the variables + * @throws IllegalArgumentException If predictions and labels have + * mismatched shapes, or if sampleWeight is not null>and its shape + * doesn't match predictions + * @return an op to update the given confusion matrix variables. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static List updateConfusionMatrixVariables( + Ops tf, + Map> variablesToUpdate, + Map> varInitializers, + Operand labels, + Operand predictions, + float[] thresholds, + Integer topK, + Integer classId, + Operand sampleWeight, + boolean multiLabel, + Operand labelWeights) { + if (multiLabel && labelWeights != null) + throw new IllegalArgumentException( + "labelWeights for multilabel data should be handled outside of updateConfusionMatrixVariables when multiLabel is true."); + + if (variablesToUpdate == null || variablesToUpdate.isEmpty()) { + return Collections.EMPTY_LIST; + } + + Operand lLabels = labels; + Operand lPredictions = predictions; + Operand lSampleWeight = sampleWeight; + + Operand numThresholds; + Operand oneThresh; + if (multiLabel) { + numThresholds = tf.shape.size(lLabels, tf.constant(0)); + oneThresh = tf.math.equal(tf.constant(1), tf.constant(thresholds.length)); + } else { + // TODO handle Ragged Tensors???? 
+ // [y_pred, + // y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], + // sampleWeights) + numThresholds = tf.constant(thresholds.length); + oneThresh = tf.constant(true); + } + + List controlOps = new ArrayList<>(); + Operand axes = allAxes(tf, lPredictions); + controlOps.add( + tf.withSubScope("updateConfusionMatrixVariables-1") + .assertThat( + tf.reduceAll( + tf.math.greaterEqual( + lPredictions, cast(tf, tf.constant(0), lPredictions.type())), + axes), + Collections.singletonList(tf.constant("predictions must be >= 0")))); + controlOps.add( + tf.withSubScope("updateConfusionMatrixVariables-2") + .assertThat( + tf.reduceAll( + tf.math.lessEqual(lPredictions, cast(tf, tf.constant(1), lPredictions.type())), + axes), + Collections.singletonList(tf.constant("predictions must be <= 1")))); + + LossTuple result = + LossesHelper.squeezeOrExpandDimensions(tf, lLabels, lPredictions, lSampleWeight); + lPredictions = result.getTarget(); + lLabels = result.getLabels(); + lSampleWeight = result.getSampleWeights(); + + if (!lPredictions.shape().isCompatibleWith(lLabels.shape())) + throw new IllegalArgumentException( + String.format( + "Shapes %s and %s are incompatible)", + lPredictions.shape().toString(), lLabels.asOutput().shape().toString())); + + if (topK != null) { + lPredictions = filterTopK(tf, lPredictions, topK); + } + + if (classId != null) { + lLabels = tf.squeeze(tf.gather(lLabels, tf.constant(new int[] {classId}), tf.constant(1))); + lPredictions = + tf.squeeze(tf.gather(lPredictions, tf.constant(new int[] {classId}), tf.constant(1))); + lLabels = tf.expandDims(lLabels, tf.constant(0)); + lPredictions = tf.expandDims(lPredictions, tf.constant(0)); + } + org.tensorflow.op.core.Shape predShape = tf.shape(lPredictions); + Operand numPredictions = + tf.reshape(tf.shape.size(lPredictions, tf.constant(0)), tf.constant(Shape.scalar())); + Operand numLabels = + tf.select( + tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), + 
tf.constant(1), + tf.reduceProd( + tf.shape.takeLast( + predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), + tf.constant(0))); + Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); + + Operand predictionsExtraDim; + Operand labelsExtraDim; + if (multiLabel) { + predictionsExtraDim = tf.expandDims(lPredictions, tf.constant(0)); + labelsExtraDim = tf.expandDims(cast(tf, lLabels, TBool.class), tf.constant(0)); + } else { + predictionsExtraDim = tf.reshape(lPredictions, tf.constant(Shape.of(1, -1))); + labelsExtraDim = tf.reshape(cast(tf, lLabels, TBool.class), tf.constant(Shape.of(1, -1))); + } + List> threshPretileShape; + List> threshTiles; + List> dataTiles; + if (multiLabel) { + threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); + + threshTiles = Arrays.asList(tf.constant(1), numPredictions, threshLabelTile); + dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1)); + } else { + threshPretileShape = Arrays.asList(numThresholds, tf.constant(-1)); + Operand mul = tf.math.mul(numPredictions, numLabels); + threshTiles = Arrays.asList(tf.constant(1), mul); + dataTiles = Arrays.asList(numThresholds, tf.constant(1)); + } + + Operand thresholdsReshaped = + tf.reshape( + cast(tf, tf.constant(thresholds), predictions.type()), tf.stack(threshPretileShape)); + Operand threshTilesShape = tf.stack(threshTiles); + Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); + Operand predsTiled = tf.tile(predictionsExtraDim, tf.stack(dataTiles)); + + // Compare predictions and threshold. 
+ Operand predIsPos = tf.math.greater(predsTiled, threshTiled); + // Tile labels by number of thresholds + Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles)); + Operand weightsTiled; + if (lSampleWeight != null) { + lSampleWeight = + tf.broadcastTo(cast(tf, lSampleWeight, predictions.type()), tf.shape(lPredictions)); + weightsTiled = tf.tile(tf.reshape(lSampleWeight, tf.stack(threshTiles)), tf.stack(dataTiles)); + } else { + weightsTiled = null; + } + + if (labelWeights != null) { + Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0)); + lLabelWeights = tf.broadcastTo(cast(tf, lLabelWeights, labelWeights.type()), lPredictions); + Operand labelWeightsTiled = + tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles)); + if (weightsTiled == null) { + weightsTiled = labelWeightsTiled; + } else { + weightsTiled = tf.math.mul(weightsTiled, labelWeightsTiled); + } + } + + Map loopVars = new HashMap<>(); + loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, new Operand[] {labelIsPos, predIsPos}); + Variable update_tn = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); + Variable update_fp = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); + Variable update_fn = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); + + Operand predIsNeg = null; + Operand labelIsNeg; + if (update_fn != null || update_tn != null) { + predIsNeg = tf.math.logicalNot(predIsPos); + loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, new Operand[] {labelIsPos, predIsNeg}); + } + + if (update_fp != null || update_tn != null) { + labelIsNeg = tf.math.logicalNot(labelIsPos); + loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, new Operand[] {labelIsNeg, predIsPos}); + if (update_tn != null) { + loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, new Operand[] {labelIsNeg, predIsNeg}); + } + } + + final Operand weightsTiledF = weightsTiled; + loopVars + .keySet() + .forEach( + (c) -> { + if 
(variablesToUpdate.containsKey(c)) { + Operand[] op = loopVars.get(c); + // op[0] = label, op[1] == prediction + controlOps.add( + weightedAssignAdd( + tf, + op[0], + op[1], + weightsTiledF, + variablesToUpdate.get(c), + varInitializers.get(c))); + } + }); + + return controlOps; + } + + /** + * Creates an Operand that adds the values by taking the logical and of labels and predictions to + * the specified confusion matrix variable. + * + * @param tf The TensorFlow Ops + * @param labels the labels + * @param predictions the predictions + * @param weights the weights applied to the logical and result, may be null + * @param variable the variable to update + * @param initializer the variable initializer to be applied to the variable, may be null. + * @param the data type for the variable. + * @return an Operand that updates the variable. + */ + private static Operand weightedAssignAdd( + Ops tf, + Operand labels, + Operand predictions, + Operand weights, + Variable variable, + Assign initializer) { + Class type = variable.type(); + Operand labelAndPred = cast(tf, tf.math.logicalAnd(labels, predictions), type); + + if (weights != null) { + Operand lWeights = cast(tf, weights, type); + labelAndPred = tf.math.mul(labelAndPred, lWeights); + } + Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); + Operand assignAdd; + if (initializer != null) { + Ops tfc = + tf.withSubScope("weightedAssignAdd") + .withControlDependencies(Collections.singletonList(initializer)); + assignAdd = tfc.assignAdd(variable, valueSum); + } else { + assignAdd = tf.assignAdd(variable, valueSum); + } + return assignAdd; + } + + /** + * Filters top-k values in the last dim of x and set the rest to NEG_INF. + * + *

Used for computing top-k prediction values in dense labels (which has the same shape as + * predictions) for recall and precision top-k metrics. + * + * @param tf The TensorFlow Ops + * @param x the tensor with any dimensions to filter + * @param topK the number of values to keep. + * @param the data type for x and the return value. + * @return the topK prediction values. + */ + private static Operand filterTopK(Ops tf, Operand x, int topK) { + Class type = x.type(); + Shape xShape = x.shape(); + TopK top = tf.nn.topK(x, tf.constant(topK), TopK.sorted(false)); + OneHot oneHot = + tf.oneHot( + top.indices(), + cast(tf, tf.constant(xShape.size(xShape.numDimensions() - 1)), TInt32.class), + tf.constant(1), + tf.constant(0), + OneHot.axis(-1L)); + Operand topKMask = cast(tf, tf.reduceSum(oneHot, tf.constant(-2)), type); + + // x * top_k_mask + NEG_INF * (1 - top_k_mask) + Operand add1 = tf.math.mul(x, topKMask); + Operand add2 = + tf.math.mul( + cast(tf, tf.constant(NEG_INF), type), + tf.math.sub(cast(tf, tf.constant(1), type), topKMask)); + return tf.math.add(add1, add2); + } + + // alias for mean /** * Calculate the mean of the operand, along all axes and keepDims is false @@ -279,6 +661,103 @@ public static Operand mean( return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } + public static + LossTuple raggedAssertCompatibleAndGetFlatValues( + Ops tf, Operand labels, Operand predictions) { + // TODO handle ragged Tensors + Operand tLabels = cast(tf, labels, predictions.type()); + return new LossTuple<>(tLabels, predictions); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + * @param tf the TensorFlow Ops + * @param labels 1-D `Tensor` of real labels for the classification task. + * @param predictions 1-D `Tensor` of predictions for a given classification. + * @param numClasses The possible number of labels the classification task can have. 
+ * @param weights optional weights to be applied to the confusion matrix + * @param type Data type of the confusion matrix. + * @param the type of Operands + * @return A Operand of type type with shape [n, n] + * representing the confusion matrix, where n is the number of possible labels in + * the classification task. + * @throws IllegalArgumentException If both predictions and labels do + * not have compatible shapes, or if weights is notnull and its + * shape is not compatible with predictions. + */ + public static Operand confusionMatrix( + Ops tf, + Operand labels, + Operand predictions, + Operand numClasses, + Operand weights, + Class type) { + if (!predictions.shape().isCompatibleWith(labels.shape())) + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with labels shape %s", + predictions.shape().toString(), labels.shape().toString())); + tf = tf.withSubScope("confusionMatrix"); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, predictions, labels, null); + Operand lPredictions = cast(tf, ops.getTarget(), TInt64.class); + Operand lLabels = cast(tf, ops.getLabels(), TInt64.class); + + List labelControls = new ArrayList<>(); + List predictionControls = new ArrayList<>(); + + labelControls.add( + tf.assertThat( + tf.reduceAny(tf.math.greaterEqual(lLabels, tf.constant(0L)), allAxes(tf, lLabels)), + Collections.singletonList(tf.constant("`labels` contains negative values")))); + + predictionControls.add( + tf.assertThat( + tf.reduceAny( + tf.math.greaterEqual(lPredictions, tf.constant(0L)), allAxes(tf, lPredictions)), + Collections.singletonList(tf.constant("`predictions` contains negative values")))); + if (numClasses == null) { + numClasses = + tf.math.maximum( + tf.reduceMax(lPredictions, allAxes(tf, lPredictions)), + tf.reduceMax(lLabels, allAxes(tf, lLabels))); + } else { + labelControls.add( + tf.assertThat( + tf.reduceAny(tf.math.less(lLabels, numClasses), allAxes(tf, lLabels)), + 
Collections.singletonList(tf.constant("``labels` out of bounds")))); + predictionControls.add( + tf.assertThat( + tf.reduceAny(tf.math.less(lPredictions, numClasses), allAxes(tf, lPredictions)), + Collections.singletonList(tf.constant("``predictions` out of bounds")))); + } + + if (weights != null) { + if (!lPredictions.shape().isCompatibleWith(weights.shape())) { + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with weights shape %s", + lPredictions.shape().toString(), weights.shape().toString())); + } + } + + Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls); + lLabels = tfc.identity(lLabels); + + tfc = tf.withSubScope("confusionMatrixPredicitons").withControlDependencies(predictionControls); + lPredictions = tfc.identity(lPredictions); + + Operand shape = tf.stack(Arrays.asList(numClasses, numClasses)); + Operand indices = tf.stack(Arrays.asList(lLabels, lPredictions), Stack.axis(1L)); + Operand values = + weights == null ? 
cast(tf, tf.onesLike(lPredictions), type) : cast(tf, weights, type); + SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); + Operand zeroMatrix = tf.zeros(shape, type); + + return tf.sparse.sparseTensorDenseAdd( + cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); + } + /** * Calculate the mean of the operand, along all axes and keepDims is false * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java new file mode 100644 index 00000000000..3949ede822a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -0,0 +1,277 @@ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Abstract base class for computing sensitivity and specificity. 
+ * + * @param The data type for the metric result + */ +public abstract class SensitivitySpecificityBase extends Metric { + + public static final int DEFAULT_NUM_THRESHOLDS = 200; + + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + protected final int numThresholds; + protected final float value; + protected final float[] thresholds; + private final String truePositivesName; + private final String falsePositivesName; + private final String trueNegativesName; + private final String falseNegativesName; + private final Class type; + protected Variable truePositives; + protected Variable falsePositives; + protected Variable trueNegatives; + protected Variable falseNegatives; + + private Assign truePositivesInitializer; + private Assign falsePositivesInitializer; + private Assign trueNegativesInitializer; + private Assign falseNegativesInitializer; + + /** + * Creates a SensitivitySpecificityBase Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric instance, if null then {@link Class#getSimpleName()} is used + * @param value A scalar value in range `[0, 1]` + * @param numThresholds The number of thresholds to use for matching the given recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0. 
+ */ + protected SensitivitySpecificityBase( + Ops tf, String name, float value, int numThresholds, long seed, Class type) { + super(tf, name, seed); + if (numThresholds <= 0) throw new IllegalArgumentException("numThresholds must be > 0."); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); + this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES); + this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); + + this.value = value; + this.numThresholds = numThresholds; + + if (this.numThresholds == 1) { + this.thresholds = new float[] {0.5f}; + } else { + this.thresholds = new float[numThresholds]; + for (int i = 0; i < numThresholds - 2; i++) { + this.thresholds[i + 1] = (i + 1f) / (float) (numThresholds - 1); + } + this.thresholds[numThresholds - 1] = 1f; + } + init(); + } + + /** Initializes the Variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Shape varShape = Shape.of(numThresholds); + Operand zero = zeros.call(tf.constant(varShape), type); + + if (this.getTruePositives() == null) { + + truePositives = tf.withName(truePositivesName).variable(zero); + truePositivesInitializer = tf.assign(truePositives, zero); + } + if (this.getFalsePositives() == null) { + + falsePositives = tf.withName(falsePositivesName).variable(zero); + falsePositivesInitializer = tf.assign(falsePositives, zero); + } + if (this.getTrueNegatives() == null) { + + trueNegatives = tf.withName(trueNegativesName).variable(zero); + trueNegativesInitializer = tf.assign(trueNegatives, zero); + } + if (this.getFalseNegatives() == null) { + + falseNegatives = tf.withName(falseNegativesName).variable(zero); + falseNegativesInitializer = tf.assign(falseNegatives, zero); + } + } + + public Op initializeVariables() { + List varInitializers = new ArrayList<>(); + + if(truePositivesInitializer != null ) { + varInitializers.add(truePositivesInitializer); 
+ } + if(falsePositivesInitializer != null ) { + varInitializers.add(falsePositivesInitializer); + } + if(trueNegativesInitializer != null ) { + varInitializers.add(trueNegativesInitializer); + } + if(falseNegativesInitializer != null ) { + varInitializers.add(falseNegativesInitializer); + } + + return getTF().withControlDependencies(varInitializers).noOp(); + + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf = getTF(); + Operand tLabels = cast(tf, labels, type); + Operand tPredictions = cast(tf, predictions, type); + Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null; + + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.getTruePositives()); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.getFalsePositives()); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.getTrueNegatives()); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.getFalseNegatives()); + + return MetricsHelper.updateConfusionMatrixVariables( + tf, + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + this.getThresholds(), + null, + null, + tSampleWeights, + false, + null); + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializeVariables(); + } + + /** + * Gets the truePositives variable + * + * @return the truePositives + */ + public Variable getTruePositives() { + return truePositives; + } + + /** + * Gets the falsePositives variable + * + * @return the falsePositives truePositives + */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** + * Gets the trueNegatives variable + * + * @return the trueNegatives truePositives + */ + public Variable getTrueNegatives() { + return trueNegatives; + } + + /** + * Gets the falseNegatives variable + * + * @return the falseNegatives 
truePositives + */ + public Variable getFalseNegatives() { + return falseNegatives; + } + + /** + * Gets the numThresholds + * + * @return the numThresholds + */ + public int getNumThresholds() { + return numThresholds; + } + + /** + * Gets the value + * + * @return the value + */ + public float getValue() { + return value; + } + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return thresholds; + } + + /** + * Gets the truePositives variable name + * + * @return the truePositivesName + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the falsePositives variable name + * + * @return the falsePositivesName + */ + public String getFalsePositivesName() { + return falsePositivesName; + } + + /** + * Gets the trueNegatives variable name + * + * @return the trueNegativesName + */ + public String getTrueNegativesName() { + return trueNegativesName; + } + + /** + * Gets the falseNegatives variable name + * + * @return the falseNegativesName + */ + public String getFalseNegativesName() { + return falseNegativesName; + } + + /** + * Gets the type + * + * @return the type + */ + public Class getType() { + return type; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java new file mode 100644 index 00000000000..d28185ae041 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
package org.tensorflow.framework.metrics.impl;

import org.tensorflow.Operand;
import org.tensorflow.types.family.TNumber;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Pairs an operand with a symbolic name for each of its dimensions (e.g. {@code "N"},
 * {@code "L"}), so that callers can check that related operands agree on the sizes of
 * dimensions that share a symbol.
 *
 * @param <T> the data type of the operand
 */
public class SymbolicShape<T extends TNumber> {
  private Operand<T> operand;
  private List<String> symbols = new ArrayList<>();

  /**
   * Creates a SymbolicShape.
   *
   * @param operand the operand whose shape is being described
   * @param symbols one symbolic dimension name per axis of the operand
   */
  public SymbolicShape(Operand<T> operand, String... symbols) {
    this.operand = operand;
    this.symbols.addAll(Arrays.asList(symbols));
  }

  /** @return the operand */
  public Operand<T> getOperand() {
    return operand;
  }

  /** @param operand the operand to set */
  public void setOperand(Operand<T> operand) {
    this.operand = operand;
  }

  /** @return the symbols */
  public List<String> getSymbols() {
    return symbols;
  }

  /** @param symbols the symbols to set */
  public void setSymbols(List<String> symbols) {
    this.symbols = symbols;
  }

  /** @return the number of symbolic dimensions, i.e. the rank this shape describes */
  public int rank() {
    return this.symbols.size();
  }
}
package org.tensorflow.framework.metrics.impl;

import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TNumber;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.tensorflow.framework.utils.CastHelper.cast;

/** Utility ops for asserting that metric weights are broadcastable to values, and for
 * broadcasting them. */
public class WeightsBroadcastOps {

  private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX =
      "weights can not be broadcast to values.";

  /** Utility class: no instances. */
  private WeightsBroadcastOps() {}

  /**
   * Asserts that {@code weights} can be broadcast to {@code values}.
   *
   * <p>Checks that are decidable from the static shapes are performed immediately (throwing
   * {@link IllegalArgumentException}); otherwise a runtime assertion op is built.
   *
   * @param tf the TensorFlow Ops
   * @param weights the weights tensor
   * @param values the values tensor to which weights are applied
   * @param <T> the type of weights and values
   * @return an op raising {@code InvalidArgumentError} at run time if {@code weights} has an
   *     incompatible shape, or a no-op if static checks prove the shape is correct
   * @throws IllegalArgumentException if static checks determine {@code weights} has an
   *     incompatible shape
   */
  public static <T extends TNumber> Op assertBroadcastable(
      Ops tf, Operand<T> weights, Operand<T> values) {
    Operand<TInt32> weightsShape = tf.shape(weights);
    Operand<TInt32> weightsRank = tf.rank(weights);
    Shape weightsShapeStatic = weights.shape();
    int weightsRankStatic = weightsShapeStatic.numDimensions();

    Operand<TInt32> valuesShape = tf.shape(values);
    Operand<TInt32> valuesRank = tf.rank(values);
    Shape valuesShapeStatic = values.shape();
    int valuesRankStatic = valuesShapeStatic.numDimensions();

    // Static checks: both ranks are known at graph-construction time.
    if (weightsRankStatic != -1 && valuesRankStatic != -1) {
      if (weightsRankStatic == 0) {
        // A scalar weight always broadcasts.
        return tf.withSubScope("staticScalarCheckSuccess")
            .withControlDependencies(Collections.emptyList())
            .noOp();
      }
      if (weightsRankStatic != valuesRankStatic) {
        throw new IllegalArgumentException(
            String.format(
                "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.",
                ASSERT_BROADCASTABLE_ERROR_PREFIX,
                valuesRankStatic,
                weightsRankStatic,
                valuesShapeStatic.toString(),
                weightsShapeStatic.toString()));
      }

      for (int i = 0; i < valuesRankStatic; i++) {
        if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) {
          throw new IllegalArgumentException(
              String.format(
                  "%s Mismatch at dim %s. values.shape=%s weights.shape=%s.",
                  ASSERT_BROADCASTABLE_ERROR_PREFIX,
                  i,
                  valuesShapeStatic.toString(),
                  weightsShapeStatic.toString()));
        }
      }
      return tf.withSubScope("staticDimsCheckSuccess")
          .withControlDependencies(Collections.emptyList())
          .noOp();
    }

    // Dynamic checks: at least one rank is unknown statically, so assert at run time.
    Operand<TBool> isScalar = tf.math.equal(weightsRank, tf.constant(0));
    List<Operand<?>> data =
        Arrays.asList(
            tf.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX),
            tf.constant("weights.shape="),
            weightsShape,
            tf.constant("values.shape="),
            valuesShape,
            tf.constant("is_scalar="),
            isScalar);

    // Valid when the weight is a scalar, or when ranks match and every dim agrees.
    Operand<TBool> isValidShape =
        tf.select(
            isScalar,
            isScalar,
            hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape));

    return tf.assertThat(isValidShape, data);
  }

  /**
   * Checks that weights and values have the same rank and, if so, that each corresponding
   * dimension matches.
   *
   * @param tf the TensorFlow Ops
   * @param weightsRank the rank operand for the weights
   * @param weightsShape the shape operand for the weights
   * @param valuesRank the rank operand for the values
   * @param valuesShape the shape operand for the values
   * @return a boolean Operand, true if both shapes have the same rank and each dimension is the
   *     same
   */
  private static Operand<TBool> hasValidNonscalarShape(
      Ops tf,
      Operand<TInt32> weightsRank,
      Operand<TInt32> weightsShape,
      Operand<TInt32> valuesRank,
      Operand<TInt32> valuesShape) {
    tf = tf.withSubScope("hasValidNonscalarShape");
    Operand<TBool> isSameRank = tf.math.equal(valuesRank, weightsRank);
    return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank);
  }

  /**
   * Checks that each dimension of the two shapes is the same.
   *
   * <p>NOTE(review): this compares reduceSum(weightsShape - valuesShape) to 0, so offsetting
   * mismatches could cancel out; it mirrors the Python weights_broadcast_ops helper — confirm
   * before strengthening. The "hasInvalidDims" sub-scope name also mirrors the Python version.
   *
   * @param tf the TensorFlow Ops
   * @param weightsShape the shape of the weights
   * @param valuesShape the shape of the values
   * @return a boolean Operand, true if all the dimensions of the two shapes are the same
   */
  private static Operand<TBool> hasValidDims(
      Ops tf, Operand<TInt32> weightsShape, Operand<TInt32> valuesShape) {
    tf = tf.withSubScope("hasInvalidDims");
    Operand<TInt32> diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0));
    return tf.math.equal(tf.constant(0), diff);
  }

  /**
   * Broadcasts {@code weights} to the same shape as {@code values}.
   *
   * <p>This returns a version of weights following the same broadcast rules as
   * {@code mul(weights, values)}, but limited to the weights shapes allowed by
   * {@code assertBroadcastable}. When computing a weighted average, use this function to
   * broadcast weights before summing them; e.g.
   * {@code reduceSum(w * v) / reduceSum(broadcastWeights(w, v))}.
   *
   * @param tf the TensorFlow ops
   * @param weights a tensor whose shape is broadcastable to {@code values}
   * @param values a tensor of any shape
   * @param <T> the type of Operand
   * @return weights broadcast to the values shape
   */
  public static <T extends TNumber> Operand<T> broadcastWeights(
      Ops tf, Operand<T> weights, Operand<T> values) {
    tf = tf.withSubScope("broadcast_weights");
    Operand<T> tValues = cast(tf, values, weights.type());

    Shape weightsShape = weights.shape();
    Shape valuesShape = tValues.shape();

    // Fast path: statically compatible shapes need no runtime check or multiply.
    if (!weightsShape.hasUnknownDimension()
        && !valuesShape.hasUnknownDimension()
        && weightsShape.isCompatibleWith(valuesShape)) {
      return weights;
    }

    // Gate the broadcasting multiply on the runtime broadcastability assertion.
    Op dependencies = assertBroadcastable(tf, weights, tValues);
    Ops tf1 =
        tf.withSubScope("assertBroadcastable")
            .withControlDependencies(Collections.singletonList(dependencies));
    return tf1.math.mul(weights, tf.onesLike(tValues));
  }
}
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.utils; + +import org.tensorflow.Operand; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * This is a helper class that represents a sparse tensor who's attributes may be passed to + * {@link org.tensorflow.op.Ops#sparse} methods. + * + * @param the type of the SparseTensor + */ +public class SparseTensor { + private final Operand indices; + private final Operand values; + private final Operand denseShape; + + /** + * Creates a SparseTensor + * + * @param indices A 2-D int64 tensor of shape `[N, ndims]`, which specifies the + * indices of the elements in the sparse tensor that contain nonzero values + * @param values A 1-D tensor of any type and shape `[N]`, which supplies the + * values for each element in `indices`. + * @param denseShape A 1-D int64 tensor of shape `[ndims]`, which specifies the + * dense_shape of the sparse tensor + * @throws IllegalArgumentException When building an eager SparseTensor if `dense_shape` is + * unknown or contains unknown elements (None or -1). 
+ */ + public SparseTensor (Operand indices, Operand values, Operand denseShape) { + this.indices = indices; + this.values = values; + this.denseShape = denseShape; + } + + /** + * Gets the indices for the Sparse Tensor + * @return the indices + */ + public Operand getIndices() { + return indices; + } + + /** + * Get the values for the Sparse Tensor + * @return the values + */ + public Operand getValues() { + return values; + } + + /** + * Gets the dense shape for the Sparse Tensor + * + * @return the denseShape + */ + public Operand getDenseShape() { + return denseShape; + } + +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java new file mode 100644 index 00000000000..88825b5f32e --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -0,0 +1,324 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +import static org.junit.jupiter.api.Assertions.*; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class AUCTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float epsilon = 1e-4F; + + int numThresholds = 3; + float[] predArray = new float[] {0f, 0.5f, 0.3f, 0.9f}; + int[] trueArray = new int[] {0, 0, 1, 1}; + float[] sampleWeight = new float[] {1, 2, 3, 4}; + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(predArray); + Operand yTrue = tf.constant(trueArray); + AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + + session.run(tf.init()); + + Op update = instance.updateState(yTrue, yPred, null); + + for (int i = 0; i < 10; i++) { + session.run(update); + } + + Operand result = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(result, instance.result()); + } + } + } + + @Test + public void basicTestSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + assertEquals(numThresholds, instance.getNumThresholds()); + float[] expectedThresholds = new float[] {-1e-7f, 0.5f, 1 + 1e-7f}; + assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); + + instance.resetStates(); + Operand yPred = tf.constant(new float[] {0, 0, 1, 1}); + Operand yTrue = tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}); + Operand 
sampleWeights = tf.constant(new float[] {1, 0, 0, 1}); + + Op update = instance.updateState(yTrue, yPred, sampleWeights); + session.run(update); + Operand result = instance.result(); + session.evaluate(1.0f, result); + } + } + + @Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yTrue = cast(tf, tf.constant(this.trueArray), TFloat32.class); + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + + Op update = instance.updateState(yTrue, yTrue, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(1f, result); + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + Operand result = instance.result(); + + // float expectedResult = (0.75f * 1 + 0.25f * 0); + session.evaluate(0.75f, result); + } + } + + @Test + public void testManualThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + AUC instance = new AUC<>(tf, new float[] {0.5f}, 1001L, TFloat32.class); + float[] expectedThresholds = new float[] {-AUC.EPSILON, 0.5f, 1 + AUC.EPSILON}; + assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + Operand result = instance.result(); + + // float expectedResult = (0.75f * 1 + 0.25f * 0); + session.evaluate(0.75f, result); + } + } + + @Test + public 
void testWeightedRocInterpolation() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = (0.78571427f * 1 + 0.2857145f * 0); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedRocMajoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.ROC, + AUCSummationMethod.MAJORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = (1.0f + .5714285f * 0f); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedRocMinoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.ROC, + AUCSummationMethod.MINORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = ( 0.5714285f + 0f * 0f); + session.evaluate(expectedResult, 
result); + } + } + + @Test + public void testWeightedPrMajoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.PR, + AUCSummationMethod.MAJORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.4285715f + 0.5714285f; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedPrMinoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.PR, + AUCSummationMethod.MINORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.7f * 0.4285715f + 0f * 0.5714285f; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedPrInterpolation() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>(tf, this.numThresholds, AUCCurve.PR, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.916613f; + 
session.evaluate(expectedResult, result); + } + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + new AUC<>(tf, -1, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testExtraDims() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // logits = scipy.special.expit(-np.array([[[-10., 10., -10.], [10., -10., 10.]], + // [[-12., 12., -12.], [12., -12., 12.]]], + // dtype=np.float32)) + float[][][] logitsArray = { + { + {9.99954602e-01f, 4.53978687e-05f, 9.99954602e-01f}, + {4.53978687e-05f, 9.99954602e-01f, 4.53978687e-05f} + }, + { + {9.99993856e-01f, 6.14417460e-06f, 9.99993856e-01f}, + {6.14417460e-06f, 9.99993856e-01f, 6.14417460e-06f} + } + }; + + long[][][] labelArray = { + {{1, 0, 0}, {1, 0, 0}}, + {{0, 1, 1}, {0, 1, 1}} + }; + + Operand logits = tf.constant(logitsArray); + Operand labels = tf.constant(labelArray); + + AUC instance = new AUC<>(tf, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(labels, logits, null); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.5f; + session.evaluate(expectedResult, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java new file mode 100644 index 00000000000..48cac95b8a6 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java @@ -0,0 +1,130 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class AccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2, 3, 4}; + float[] predArray = {1, 2, 3, 4}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4F, total); + session.evaluate(4, count); + session.evaluate(1F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] 
trueArray = {2, 1}; + float[] predArray = {2, 0}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(.5F, total); + session.evaluate(.7, count); + session.evaluate(0.71428573f, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] trueArray = {2, 1}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // 2nd run + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(1.4F, total); + session.evaluate(1.4, count); + session.evaluate(1.0F, result); + + // new instance same graph + instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + total = instance.getTotal(); + count = instance.getCount(); + result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, 
count); + session.evaluate(1.0F, result); + + // reset variables + session.run(instance.resetStates()); + result = instance.result(); + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java new file mode 100644 index 00000000000..e8d8350dcdc --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java @@ -0,0 +1,177 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class BinaryAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 0}; + float[] predArray = {1, 0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(2, count); + session.evaluate(1F, result); + } + } + + @Test + public void testPredictionSqueeze() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 0}; + float[] predArray = {1, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = 
instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(4, count); + session.evaluate(0.5F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 1}; + float[] predArray = {1, 0}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.5F, total); + session.evaluate(.7, count); + session.evaluate(0.71428573f, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] trueArray = {2, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + + // 2nd run + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + 
session.evaluate(0.4F, total); + session.evaluate(1.4, count); + session.evaluate(0.2857143F, result); + + // new instance same graph + instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + total = instance.getTotal(); + count = instance.getCount(); + result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + + // reset variables + session.run(instance.resetStates()); + session.evaluate(0.0, total); + session.evaluate(0.0, count); + + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + } + } + + @Test + public void testBinaryAccuracyAThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = + new BinaryAccuracy<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 1, 0, 0}; + float[] predArray = {0.9f, 0.6f, 0.4f, 0.8f}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 1))); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(4, count); + session.evaluate(0.5F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java new file mode 100644 index 00000000000..83990cbaebb --- /dev/null +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java @@ -0,0 +1,156 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class CategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = + new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + 
Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(2, count); + session.evaluate(1F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = + new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = + new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + 
session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // 2nd run + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(1.4F, total); + session.evaluate(1.4, count); + session.evaluate(1.0F, result); + + // new instance same graph + instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + total = instance.getTotal(); + count = instance.getCount(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // reset variables + session.run(instance.resetStates()); + session.evaluate(0, total); + session.evaluate(0, count); + session.evaluate(0, result); + + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java new file mode 100644 index 00000000000..4bd8d99586e --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class FalseNegativesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + FalseNegatives instance = new FalseNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(3.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + FalseNegatives instance = new FalseNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(5.0, result); 
+ } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + FalseNegatives instance = + new FalseNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {1.f, 4.f, 6.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(new double[][] {{3.0}, {5.0}, {7.0}, {4.0}}); + FalseNegatives instance = + new FalseNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {4., 16., 23.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java new file mode 100644 index 00000000000..2584c7a3244 
--- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java @@ -0,0 +1,148 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class FalsePositivesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + FalsePositives instance = new FalsePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + 
session.evaluate(7.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + FalsePositives instance = new FalsePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(14.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + FalsePositives instance = + new FalsePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {7.f, 4.f, 2.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = + tf.constant( + new double[][] { + {1.0, 2.0, 3.0, 5.0}, + {7.0, 
11.0, 13.0, 17.0}, + {19.0, 23.0, 29.0, 31.0}, + {19.0, 23.0, 29.0, 31.0} + }); + FalsePositives instance = + new FalsePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {125., 42., 12.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java new file mode 100644 index 00000000000..fc08455d1c7 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +public class MeanIoUTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final long numClasses = 2L; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testUnweighted"); + Operand predictions = tf.constant(new long[] {0, 1, 0, 1}); + Operand labels = tf.constant(new long[] {0, 0, 1, 1}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + double expected_result = (1. / (2. + 2. - 1.) + 1. / (2. + 2. 
- 1.)) / 2.; + session.evaluate(expected_result, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testWeighted"); + Operand predictions = tf.constant(new long[] {0, 1, 0, 1}); + Operand labels = tf.constant(new long[] {0, 0, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {0.2f, 0.3f, 0.4f, 0.1f}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + float expected_result = (0.2f / (0.6f + 0.5f - 0.2f) + 0.1f / (0.4f + 0.5f - 0.1f)) / 2f; + session.evaluate(expected_result, result); + } + } + + @Test + public void testMultiDimInput() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testMultiDimInput"); + + Operand predictions = tf.constant(new long[][] {{0, 1}, {0, 1}}); + Operand labels = tf.constant(new long[][] {{0, 0}, {1, 1}}); + Operand sampleWeight = tf.constant(new float[][] {{0.2f, 0.3f}, {0.4f, 0.1f}}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + float expected_result = (0.2f / (0.6f + 0.5f - 0.2f) + 0.1f / (0.4f + 0.5f - 0.1f)) / 2f; + session.evaluate(expected_result, result); + } + } + + @Test + public void testZeroValidEntries() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testZeroValidEntries"); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Operand result = instance.result(); + session.evaluate(0.0f, result); + } + } + + 
@Test + public void testZeroAndNonZeroEntries() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testZeroAndNonZeroEntries"); + Operand predictions = tf.constant(new float[] {1}); + Operand labels = tf.constant(new int[] {1}); + + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float expected_result = (0f + 1f / (1f + 1f - 1f)) / 1f; + session.evaluate(expected_result, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java new file mode 100644 index 00000000000..0bb9392b8b0 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class MeanRelativeErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] predArray = new float[][] {{2, 4, 6, 8}}; + float[][] trueArray = new float[][] {{1, 3, 2, 3}}; + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + + MeanRelativeError instance = + new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + double expected_result = 1.25; + session.evaluate(expected_result, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[] predArray = new float[] {2, 4, 6, 8}; + float[] trueArray = new float[] {1, 3, 2, 3}; + float[] sampleWeightArray = new float[] {0.2f, 0.3f, 0.5f, 0f}; + + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + Operand sampleWeight = tf.constant(sampleWeightArray); + + MeanRelativeError instance = + new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = 
instance.result(); + + double expectedResult = 1.3; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testZeroNormalizer() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[] predArray = new float[] {2, 4}; + int[] trueArray = new int[] {1, 3}; + + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + + MeanRelativeError instance = + new MeanRelativeError<>( + tf, cast(tf, tf.zerosLike(labels), TFloat32.class), 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + double expectedResult = 0; + session.evaluate(expectedResult, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java new file mode 100644 index 00000000000..ce473bbdf34 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class MeanTensorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand values = tf.constant(new long[] {100, 40}); + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); + session.run(tf.init()); + Op update = instance.updateState(values, null); + session.run(update); + Operand result = instance.result(); + double[] expected_result = new double[] {100, 40}; + session.evaluate(expected_result, result); + + session.evaluate(expected_result, instance.getTotal()); + session.evaluate(new double[] {1, 1}, instance.getCount()); + + session.run(instance.resetStates()); + session.evaluate(new double[] {0, 0}, instance.getTotal()); + session.evaluate(new double[] {0, 0}, instance.getCount()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand values = tf.constant(new long[] {100, 30}); + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); + session.run(tf.init()); + + // check scalar weight + Op update = instance.updateState(values, tf.constant(0.5f)); + session.run(update); + Operand result = instance.result(); + double[] expected_result = new double[] {100, 30}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {50, 15}, instance.getTotal()); + session.evaluate(new double[] {0.5, 0.5}, instance.getCount()); + + // check weights 
not scalar and weights rank matches values rank + values = tf.constant(new long[] {1, 5}); + update = instance.updateState(values, tf.constant(new double[] {1f, 0.2f})); + session.run(update); + result = instance.result(); + expected_result = new double[] {51 / 1.5, 16 / 0.7}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {51, 16}, instance.getTotal()); + session.evaluate(new double[] {1.5, .7}, instance.getCount()); + + // check weights broadcast + values = tf.constant(new long[] {1, 2}); + update = instance.updateState(values, tf.constant(0.5f)); + session.run(update); + result = instance.result(); + expected_result = new double[] {51.5 / 2, 17 / 1.2}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {51.5, 17}, instance.getTotal()); + session.evaluate(new double[] {2, 1.2}, instance.getCount()); + + // check weights squeeze + values = tf.constant(new long[] {1, 5}); + Operand sampleWeight = tf.constant(new double[][] {{1}, {0.2}}); + update = instance.updateState(values, sampleWeight); + session.run(update); + result = instance.result(); + expected_result = new double[] {52.5 / 3, 18 / 1.4}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {52.5, 18}, instance.getTotal()); + session.evaluate(new double[] {3, 1.4}, instance.getCount()); + } + } + + @Test + public void testWeightedExpand() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + // check weights expand + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat32.class); + + Operand values = tf.constant(new long[][] {{1}, {5}}); + Operand sampleWeight = tf.constant(new float[] {1f, 0.2f}); + Op update = instance.updateState(values, sampleWeight); + session.run(update); + Operand result = instance.result(); + session.evaluate(tf.constant(new float[][] {{1f}, {5f}}), result); + session.evaluate(tf.constant(new float[][] {{1f}, {1f}}), instance.getTotal()); + 
session.evaluate(tf.constant(new float[][] {{1f}, {0.2f}}), instance.getCount()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java new file mode 100644 index 00000000000..a817a3dc5df --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class PrecisionAtRecallTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) session.evaluate(initialPrecision, instance.result()); + } + } + + private int[][] generateRandomArray(int dim1, int dim2) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(2); + } + } + + return result; + } + + @Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new 
PrecisionAtRecall<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighRecall() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.8f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.5f, 0.4f, 0.5f, 0.6f, 0.8f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.8f, precision); + } + } + + @Test + public void testUnweightedLowRecall() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.15f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.5f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = 
session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {2, 2, 1, 1, 1, 1, 1, 2, 2, 2}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(2.f / 3.f, precision); + } + } + + @Test + public void testInvalidSensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new PrecisionAtRecall<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new PrecisionAtRecall<>(tf, 0.7f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java new file mode 100644 index 00000000000..35962a568ca --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java @@ -0,0 +1,339 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +public class PrecisionTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Precision instance = + new Precision<>(tf, new float[] {0.3f, 0.72f}, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L)); + + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) { + session.run(update); + } + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(initialPrecision, instance.result()); + } + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + 
Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new long[][] {{1, 0, 1, 0}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(0.5, precision); + } + } + + @Test + public void testUnweightedAllIncorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = new Precision<>(tf, 0.5f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniformInt(tf.constant(Shape.of(100, 1)), tf.constant(0), tf.constant(2)); + Operand labels = tf.math.sub(tf.constant(1), predictions); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(0.0f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new long[][] {{1, 0, 1, 0}, {1, 0, 1, 0}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}, {1, 0, 0, 1}}); + Operand sampleWeight = tf.constant(new double[][] {{1, 2, 3, 4}, {4, 3, 2, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand precision = instance.result(); + + double weightedTP = 3.0f + 4.0f; + double weightedPositives = (1.0f + 3.0f) + (4.0f + 2.0f); + double expectedPrecision = weightedTP / weightedPositives; + + session.evaluate(expectedPrecision, precision); + } + } + + @Test + public void testDivByZero() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); 
+ Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new int[] {0, 0, 0, 0}); + Operand labels = tf.constant(new int[] {0, 0, 0, 0}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(0, precision); + } + } + + @Test + public void testUnweightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f, 0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + float[] expected = new float[] {0.5f, 0.f}; + + session.evaluate(expected, precision); + } + } + + @Test + public void testWeightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f}, {0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1}, {1, 0}}); + Operand sampleWeight = tf.constant(new float[][] {{4, 0}, {3, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand precision = instance.result(); + + float weightedTP = 0f + 3.f; + float weightedPositives = (0f + 3.f) + (4.f + 0.f); + float expectedPrecision = weightedTP / weightedPositives; + + Float[] expected = new Float[] {expectedPrecision, 0f}; + session.evaluate(expected, precision); + } + } + + @Test + public void testMultipleUpdates() { + try (TestSession 
session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f}, {0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1}, {1, 0}}); + Operand sampleWeight = tf.constant(new double[][] {{4, 0}, {3, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + for (int i = 0; i < 2; i++) session.run(update); + Operand precision = instance.result(); + + double weighted_tp = (0 + 3.) + (0 + 3.); + double weighted_positives = ((0 + 3.) + (4. + 0.)) + ((0 + 3.) + (4. + 0.)); + double expected_precision = weighted_tp / weighted_positives; + + double[] expected = new double[] {expected_precision, 0f}; + session.evaluate(expected, precision); + } + } + + @Test + public void testUnweightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 3 + Precision instance = + new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(1.0f / 3.0f, precision); + } + } + + @Test + public void testWeightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 3 + Precision instance = + new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[] {0.2f, 0.1f, 0.4f, 0f, 0.2f}); + Operand labels = tf.constant(new long[] {0, 1, 1, 0, 1}); + Operand sampleWeight = tf.constant(new float[][] {{1, 4, 2, 3, 5}}); + 
Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.4f, 0.2f, 0.2f}}); + labels = tf.constant(new long[][] {{1, 0, 1, 1, 1}}); + update = instance.updateState(labels, predictions, tf.constant(3.f)); + session.run(update); + + Operand precision = instance.result(); + + float tp = (2f + 5f) + (3f + 3f); + float predicted_positives = (1f + 2f + 5f) + (3f + 3f + 3f); + float expected_precision = tp / predicted_positives; + session.evaluate(expected_precision, precision); + } + } + + @Test + public void testUnweightedClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set classId to 2 + Precision instance = + new Precision<>(tf, null, null, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0f, 0f, 0.2f}}); + labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + labels = tf.constant(new long[][] {{0, 1, 0, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(0.5f, precision); + session.evaluate(1, 
instance.getTruePositives()); + session.evaluate(1, instance.getFalsePositives()); + } + } + + @Test + public void testUnweightedTopKAndClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK and classId to 2 + Precision instance = + new Precision<>(tf, null, 2, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{1f, 1f, 0.9f, 1f, 1f}}); + labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + } + } + + @Test + public void testUnweightedTopKAndThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 2 + Precision instance = + new Precision<>(tf, 0.7f, 2, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 1}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + } + } +} diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java new file mode 100644 index 00000000000..bd3a5273668 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java @@ -0,0 +1,207 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class RecallAtPrecisionTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.7f, 1001L, TFloat32.class); + 
session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + labels = tf.math.mul(labels, tf.constant(2.0f)); + + Op update = instance.updateState(labels, predictions); + + for (int i = 0; i < 10; i++) { + session.run(update); + } + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(initialPrecision, instance.result()); + } + } + } + + private int[][] generateRandomArray(int dim1, int dim2, int maxVal) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(maxVal); + } + } + + return result; + } + + @Test + public void test_unweighted_all_correct() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1, 2); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighPrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.75f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] { + 0.05f, 0.1f, 0.2f, 0.3f, 
0.3f, 0.35f, 0.4f, 0.45f, 0.5f, 0.6f, 0.9f, 0.95f + }); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.5f, precision); + } + } + + @Test + public void testUnweightedLowPrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 2.0f / 3f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] { + 0.05f, 0.1f, 0.2f, 0.3f, 0.3f, 0.35f, 0.4f, 0.45f, 0.5f, 0.6f, 0.9f, 0.95f + }); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(5.f / 6f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.75f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.1f, 0.2f, 0.3f, 0.5f, 0.6f, 0.9f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {1, 2, 1, 2, 1, 2, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.6f, precision); + } + } + + @Test + public void testUnachievablePrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 2.0f / 3f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand 
predictions = tf.constant(new float[] {0.1f, 0.2f, 0.3f, 0.9f}); + Operand labels = tf.constant(new long[] {1, 1, 0, 0}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + // The highest possible precision is 1/2 which is below the required + session.evaluate(0f, precision); + } + } + + @Test + public void test_invalid_sensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new RecallAtPrecision<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void test_invalid_num_thresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new RecallAtPrecision<>(tf, 0.7f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java new file mode 100644 index 00000000000..b9d067a6ed2 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java @@ -0,0 +1,341 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; + +public class RecallTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.3f, 0.72f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class); + Operand labels = + tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class); + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialRecall = instance.result(); + for (int i = 0; i < 10; i++) session.evaluate(initialRecall, instance.result()); + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1, 0, 1, 0}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.5f, instance.result()); + } + } + + private int[][] generateRandomArray(int dim1, int dim2, int maxInt) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = 
random.nextInt(maxInt); + } + } + + return result; + } + + @Test + public void testUnweightedAllIncorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] array = generateRandomArray(100, 1, 2); + Operand predictions = tf.dtypes.cast(tf.constant(array), TFloat32.class); + Operand labels = + tf.dtypes.cast(tf.math.sub(tf.constant(1), tf.constant(array)), TFloat32.class); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.f, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[][] { + {1, 0, 1, 0}, + {0, 1, 0, 1} + }); + Operand labels = + tf.constant( + new float[][] { + {0, 1, 1, 0}, + {1, 0, 0, 1} + }); + + Operand sampleWeights = + tf.constant( + new float[][] { + {1, 2, 3, 4}, + {4, 3, 2, 1} + }); + Op update = instance.updateState(labels, predictions, sampleWeights); + session.run(update); + + float weightedTp = 3.0f + 1.0f; + float weightedT = (2.0f + 3.0f) + (4.0f + 1.0f); + float expectedRecall = weightedTp / weightedT; + + session.evaluate(expectedRecall, instance.result()); + } + } + + @Test + public void testDivByZero() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[] {0, 0, 0, 0}); + Operand labels = tf.constant(new float[] {0, 0, 0, 0}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0f, instance.result()); + } + } + + 
@Test + public void testUnweightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1, 0, 0.6f, 0}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Float[] expected = new Float[] {0.5f, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testWeightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); + Operand predictions = tf.constant(new float[][] {{1, 0}, {0.6f, 0}}); + Operand weights = tf.constant(new float[][] {{1, 4}, {3, 2}}); + + Op update = instance.updateState(labels, predictions, weights); + session.run(update); + + float weightedTp = 0 + 3.f; + float weightedPositives = (0 + 3.f) + (4.f + 0.f); + float expectedRecall = weightedTp / weightedPositives; + float[] expected = new float[] {expectedRecall, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testMultipleUpdates() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); + Operand predictions = tf.constant(new float[][] {{1, 0}, {0.6f, 0}}); + Operand weights = tf.constant(new float[][] {{1, 4}, {3, 2}}); + + Op update = instance.updateState(labels, predictions, weights); + for (int i = 0; i < 2; 
i++) session.run(update); + + float weightedTp = (0f + 3.f) + (0f + 3.f); + float weightedPositives = ((0f + 3.f) + (4.f + 0.f)) + ((0f + 3.f) + (4.f + 0.f)); + float expectedRecall = weightedTp / weightedPositives; + float[] expected = new float[] {expectedRecall, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testUnweightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0f, 1f, 1f, 0f, 0f}}); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.5f, instance.result()); + } + } + + @Test + public void testWeightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 1}}); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.4f, 0f, 0.2f}}); + Operand weights = tf.constant(new float[][] {{1, 4, 2, 3, 5}}); + + Op update = instance.updateState(labels, predictions, weights); + session.run(update); + + labels = tf.constant(new float[][] {{1, 0, 1, 1, 1}}); + predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.4f, 0.2f, 0.2f}}); + weights = tf.constant(3.f); + + update = instance.updateState(labels, predictions, weights); + session.run(update); + + float weightedTp = (2 + 5) + (3 + 3); + float weightedPositives = (4 + 2 + 5) + (3 + 3 + 3 + 3); + float expectedRecall = weightedTp / weightedPositives; + session.evaluate(expectedRecall, instance.result()); + } + } + + @Test + public void testUnweightedClassId() { + try 
(TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, null, null, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(0f, instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0f, 0f, 0.2f}}); + labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + labels = tf.constant(new float[][] {{0, 1, 0, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + } + } + + @Test + public void testUnweightedTopKAndClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, null, 2, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0, 0.2f}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(1f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(0f, 
instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{1, 1, 0.9f, 1, 1}}); + labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + } + } + + @Test + public void testUnweightedTopKAndThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, 0.7f, 2, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new float[][] {{1, 1, 1, 0, 1}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.25f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(3f, instance.getFalseNegatives()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java new file mode 100644 index 00000000000..c9ced9f5946 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +public class RootMeanSquaredErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RootMeanSquaredError instance = + new RootMeanSquaredError<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[] {2, 4, 6}); + Operand predictions = tf.constant(new float[] {1, 3, 2}); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(18, total); + session.evaluate(3, count); + session.evaluate(Math.sqrt(6), result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RootMeanSquaredError instance = + new RootMeanSquaredError<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand labels = tf.constant(new float[][] {{2, 4, 6, 8}}); + Operand predictions = tf.constant(new float[][] {{1, 3, 2, 3}}); + Operand sampleWeight = tf.constant(new double[][] {{0, 1, 0, 1}}); + + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); 
+ session.evaluate(26, total); + session.evaluate(2, count); + session.evaluate(Math.sqrt(13), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java new file mode 100644 index 00000000000..a65dc3b53da --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java @@ -0,0 +1,185 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.metrics;

import org.junit.jupiter.api.Test;
import org.tensorflow.Operand;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.random.RandomUniform;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt64;

import java.util.Random;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.tensorflow.framework.utils.CastHelper.cast;

/** Tests for the {@link SensitivityAtSpecificity} metric. */
public class SensitivityAtSpecificityTest {
  private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
  private final Random random = new Random();

  /** Verifies that repeatedly evaluating {@code result()} does not change the metric's value. */
  @Test
  public void testValueIsIdempotent() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SensitivityAtSpecificity<TFloat32> instance =
          new SensitivityAtSpecificity<>(tf, 0.7f, 1001L, TFloat32.class);
      session.run(instance.resetStates());
      Operand<TFloat32> predictions =
          tf.random.randomUniform(
              tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
      Operand<TFloat32> labels =
          tf.random.randomUniform(
              tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
      labels = tf.math.mul(labels, tf.constant(2.0f));

      Op update = instance.updateState(labels, predictions, null);

      for (int i = 0; i < 10; i++) session.run(update);

      Operand<TFloat32> initialSensitivity = instance.result();

      // result() must be a pure read of the accumulated state.
      for (int i = 0; i < 10; i++) session.evaluate(initialSensitivity, instance.result());
    }
  }

  /** Generates a {@code dim1 x dim2} array of random 0/1 values. */
  private int[][] generateRandomArray(int dim1, int dim2) {
    int[][] result = new int[dim1][dim2];
    for (int i = 0; i < dim1; i++) {
      for (int j = 0; j < dim2; j++) {
        result[i][j] = random.nextInt(2);
      }
    }

    return result;
  }

  /** When predictions exactly match labels, sensitivity is 1. */
  @Test
  public void testUnweightedAllCorrect() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SensitivityAtSpecificity<TFloat32> instance =
          new SensitivityAtSpecificity<>(tf, 0.7f, 1001L, TFloat32.class);
      session.run(instance.resetStates());
      int[][] predArray = generateRandomArray(100, 1);
      int[][] trueArray = new int[100][1]; // 100,1
      System.arraycopy(predArray, 0, trueArray, 0, predArray.length);
      Operand<TFloat32> predictions = cast(tf, tf.constant(predArray), TFloat32.class);
      Operand<TFloat32> labels = cast(tf, tf.constant(trueArray), TFloat32.class);

      Op update = instance.updateState(labels, predictions, null);
      session.run(update);

      Operand<TFloat32> precision = instance.result();

      session.evaluate(1f, precision);
    }
  }

  /** Sensitivity at a high (0.8) specificity constraint, unweighted. */
  @Test
  public void testUnweightedHighSpecificity() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SensitivityAtSpecificity<TFloat64> instance =
          new SensitivityAtSpecificity<>(tf, 0.8f, 1001L, TFloat64.class);
      session.run(instance.resetStates());
      Operand<TFloat32> predictions =
          tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.45f, 0.5f, 0.8f, 0.9f});
      Operand<TInt64> labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1});

      Op update = instance.updateState(labels, predictions, null);
      session.run(update);

      Operand<TFloat64> precision = instance.result();

      session.evaluate(0.8, precision);
    }
  }

  /** Sensitivity at a low (0.4) specificity constraint, unweighted. */
  @Test
  public void testUnweightedLowSpecificity() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SensitivityAtSpecificity<TFloat32> instance =
          new SensitivityAtSpecificity<>(tf, 0.4f, 1001L, TFloat32.class);
      session.run(instance.resetStates());
      Operand<TFloat32> predictions =
          tf.constant(
              new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f});
      Operand<TInt64> labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1});

      Op update = instance.updateState(labels, predictions, null);
      session.run(update);

      Operand<TFloat32> precision = instance.result();

      session.evaluate(0.6f, precision);
    }
  }

  /** Sensitivity with per-example sample weights. */
  @Test
  public void testWeighted() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SensitivityAtSpecificity<TFloat64> instance =
          new SensitivityAtSpecificity<>(tf, 0.4f, 1001L, TFloat64.class);
      session.run(instance.resetStates());
      Operand<TFloat32> predictions =
          tf.constant(
              new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f});
      Operand<TInt64> labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1});
      Operand<TFloat64> sampleWeight = tf.constant(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});

      Op update = instance.updateState(labels, predictions, sampleWeight);
      session.run(update);

      Operand<TFloat64> precision = instance.result();

      session.evaluate(0.675, precision);
    }
  }

  /** A specificity outside [0, 1] must be rejected at construction time. */
  @Test
  public void testInvalidSensitivity() {
    assertThrows(
        IllegalArgumentException.class,
        () -> {
          try (TestSession session = TestSession.createTestSession(tfMode)) {
            Ops tf = session.getTF();
            new SensitivityAtSpecificity<>(tf, -1f, 1001L, TFloat32.class);
          }
        });
  }

  /** A non-positive threshold count must be rejected at construction time. */
  @Test
  public void testInvalidNumThresholds() {
    assertThrows(
        IllegalArgumentException.class,
        () -> {
          try (TestSession session = TestSession.createTestSession(tfMode)) {
            Ops tf = session.getTF();
            new SensitivityAtSpecificity<>(tf, 0.7f, -1, 1001L, TFloat32.class);
          }
        });
  }
}
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.metrics;

import org.junit.jupiter.api.Test;
import org.tensorflow.Operand;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.random.RandomUniform;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;

import java.util.Random;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.tensorflow.framework.utils.CastHelper.cast;

/** Tests for the {@link SpecificityAtSensitivity} metric. */
public class SpecificityAtSensitivityTest {
  private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
  private final Random random = new Random();

  /** Verifies that repeatedly evaluating {@code result()} does not change the metric's value. */
  @Test
  public void testValueIsIdempotent() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SpecificityAtSensitivity<TFloat32> instance =
          new SpecificityAtSensitivity<>(tf, 0.7f, 1001L, TFloat32.class);
      session.run(instance.resetStates());

      Operand<TFloat32> predictions =
          tf.random.randomUniform(
              tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));
      Operand<TFloat32> labels =
          tf.random.randomUniform(
              tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L));

      Op update = instance.updateState(labels, predictions, null);

      for (int i = 0; i < 10; i++) session.run(update);

      Operand<TFloat32> initialSpecificity = instance.result();

      // result() must be a pure read of the accumulated state.
      for (int i = 0; i < 10; i++) session.evaluate(initialSpecificity, instance.result());
    }
  }

  /** Generates a {@code dim1 x dim2} array of random 0/1 values. */
  private int[][] generateRandomArray(int dim1, int dim2) {
    int[][] result = new int[dim1][dim2];
    for (int i = 0; i < dim1; i++) {
      for (int j = 0; j < dim2; j++) {
        result[i][j] = random.nextInt(2);
      }
    }

    return result;
  }

  /** When predictions exactly match labels, specificity is 1. */
  @Test
  public void testUnweightedAllCorrect() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SpecificityAtSensitivity<TFloat32> instance =
          new SpecificityAtSensitivity<>(tf, 0.7f, 1001L, TFloat32.class);
      session.run(instance.resetStates());
      int[][] predArray = generateRandomArray(100, 1);
      int[][] trueArray = new int[100][1]; // 100,1
      System.arraycopy(predArray, 0, trueArray, 0, predArray.length);
      Operand<TFloat32> predictions = cast(tf, tf.constant(predArray), TFloat32.class);
      Operand<TInt32> labels = tf.constant(trueArray);
      labels = tf.math.mul(labels, tf.constant(2));

      Op update = instance.updateState(labels, predictions, null);
      session.run(update);

      Operand<TFloat32> precision = instance.result();

      session.evaluate(1f, precision);
    }
  }

  /** Specificity at a high (0.8) sensitivity constraint, unweighted. */
  @Test
  public void testUnweightedHighSensitivity() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SpecificityAtSensitivity<TFloat32> instance =
          new SpecificityAtSensitivity<>(tf, 0.8f, 1001L, TFloat32.class);
      session.run(instance.resetStates());
      Operand<TFloat32> predictions =
          tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.45f, 0.5f, 0.8f, 0.9f});
      Operand<TInt64> labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1});

      Op update = instance.updateState(labels, predictions, null);
      session.run(update);

      Operand<TFloat32> precision = instance.result();

      session.evaluate(0.4f, precision);
    }
  }

  /** Specificity at a low (0.4) sensitivity constraint, unweighted. */
  @Test
  public void testUnweightedLowSensitivity() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SpecificityAtSensitivity<TFloat64> instance =
          new SpecificityAtSensitivity<>(tf, 0.4f, 1001L, TFloat64.class);
      session.run(instance.resetStates());
      Operand<TFloat32> predictions =
          tf.constant(
              new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f});
      Operand<TInt64> labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1});

      Op update = instance.updateState(labels, predictions, null);
      session.run(update);

      Operand<TFloat64> precision = instance.result();

      session.evaluate(0.6f, precision);
    }
  }

  /** Specificity with per-example sample weights. */
  @Test
  public void testWeighted() {
    try (TestSession session = TestSession.createTestSession(tfMode)) {
      Ops tf = session.getTF();
      SpecificityAtSensitivity<TFloat32> instance =
          new SpecificityAtSensitivity<>(tf, 0.4f, 1001L, TFloat32.class);
      session.run(instance.resetStates());
      Operand<TFloat32> predictions =
          tf.constant(
              new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f});
      Operand<TInt64> labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1});
      Operand<TFloat32> sampleWeight = tf.constant(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});

      Op update = instance.updateState(labels, predictions, sampleWeight);
      session.run(update);

      Operand<TFloat32> precision = instance.result();

      session.evaluate(0.4f, precision);
    }
  }

  /** A sensitivity outside [0, 1] must be rejected at construction time. */
  @Test
  public void testInvalidSensitivity() {
    assertThrows(
        IllegalArgumentException.class,
        () -> {
          try (TestSession session = TestSession.createTestSession(tfMode)) {
            Ops tf = session.getTF();
            new SpecificityAtSensitivity<>(tf, -1f, 1001L, TFloat32.class);
          }
        });
  }

  /** A non-positive threshold count must be rejected at construction time. */
  @Test
  public void testInvalidNumThresholds() {
    assertThrows(
        IllegalArgumentException.class,
        () -> {
          try (TestSession session = TestSession.createTestSession(tfMode)) {
            Ops tf = session.getTF();
            new SpecificityAtSensitivity<>(tf, 0.4f, -1, 1001L, TFloat32.class);
          }
        });
  }
}
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java new file mode 100644 index 00000000000..941f882b8c8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SumTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + /** Test of call method, of class Sum. 
*/ + @Test + public void testUnWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Sum instance = new Sum<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + assertEquals(TFloat32.class, instance.getResultType()); + session.evaluate(0f, instance.getTotal()); + + Op update = instance.updateState(tf.constant(100f), null); + session.run(update); + session.evaluate(100f, instance.result()); + session.evaluate(100f, instance.getTotal()); + + update = instance.updateState(tf.constant(new float[] {1, 5}), null); + session.run(update); + session.evaluate(106f, instance.result()); + session.evaluate(106f, instance.getTotal()); + + session.run(instance.resetStates()); + session.evaluate(0f, instance.getTotal()); + } + } + + @Test + public void testSumWithSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Sum instance = new Sum<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + // check scalar weight + Op op = instance.updateState(tf.constant(100f), tf.constant(0.5)); + session.run(op); + Operand result = instance.result(); + session.evaluate(50.0, instance.getTotal()); + session.evaluate(50.0, result); + + // check weights not scalar and weights rank matches values rank + op = + instance.updateState(tf.constant(new float[] {1, 5}), tf.constant(new double[] {1, 0.2})); + session.run(op); + result = instance.result(); + session.evaluate(52., instance.getTotal()); + session.evaluate(52., result); + + // check weights broadcast + op = instance.updateState(tf.constant(new float[] {1, 2}), tf.constant(0.5)); + session.run(op); + result = instance.result(); + session.evaluate(53.5, instance.getTotal()); + session.evaluate(53.5, result); + + // check weights squeeze + op = + instance.updateState( + tf.constant(new float[] {1, 5}), tf.constant(new double[][] {{1}, {0.2}})); + session.run(op); + result = 
instance.result(); + session.evaluate(55.5, instance.getTotal()); + session.evaluate(55.5, result); + + // check weights expand + op = + instance.updateState( + tf.constant(new float[][] {{1}, {5}}), tf.constant(new double[] {1, 0.2})); + session.run(op); + result = instance.result(); + session.evaluate(57.5, instance.getTotal()); + session.evaluate(57.5, result); + + // heck values reduced to the dimensions of weight + op = + instance.updateState( + tf.constant(new float[][][] {{{1.f, 2.f}, {3.f, 2.f}, {0.5f, 4.f}}}), + tf.constant(new double[] {0.5})); + session.run(op); + result = instance.result(); + session.evaluate(63.75, instance.getTotal()); + session.evaluate(63.75, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java new file mode 100644 index 00000000000..023796ba367 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class TopKCategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrectness() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand labels = tf.constant(new float[][] {{0, 0, 1}, {0, 1, 0}}); + Operand predictions = + tf.constant(new double[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1., instance.result()); + + // With `k` < 5. + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted1", 1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + + // With `k` > 5. 
+ labels = + tf.constant( + new float[][] { + {0, 0, 1, 0, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0} + }); + predictions = + tf.constant( + new double[][] { + {0.5f, 0.9f, 0.1f, 0.7f, 0.6f, 0.5f, 0.4f}, + {0.05f, 0.95f, 0f, 0f, 0f, 0f, 0f} + }); + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted6", 6, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testWeighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = + tf.constant( + new double[][] { + {1, 0, 2}, + {1, 0, 0}, + {0, 0, 1} + }); + Operand predictions = + tf.constant( + new double[][] { + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f} + }); + + Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + session.evaluate(1., instance.result()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java new file mode 100644 index 00000000000..1a68c2ed8b8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class TrueNegativesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + TrueNegatives instance = new TrueNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(3.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = 
tf.constant(this.sampleWeightArray); + TrueNegatives instance = new TrueNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(4.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + TrueNegatives instance = + new TrueNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {2.f, 5.f, 7.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(new double[][] {{0.0, 2.0, 3.0, 5.0}}); + TrueNegatives instance = + new TrueNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {5., 
15., 23.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java new file mode 100644 index 00000000000..c22c1245d97 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class TruePositivesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + TruePositives instance = new TruePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(7.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + TruePositives instance = new TruePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(12.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = 
TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + TruePositives instance = + new TruePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {6.f, 3.f, 1.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(37.); + TruePositives instance = + new TruePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {222., 111., 37.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 63d666f8640..4330fa0aed7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -63,6 +63,7 @@ private void testValid( TestSession testSession, Ops tf, Operand weights, Operand values, Class type) { Op staticOp = MetricsHelper.assertBroadcastable(tf, weights, values); + testSession.run(staticOp); // dynamic test Operand weightsPlaceholder = tf.placeholder(type); From a3aa3c42ff61014cd0c8f3020a686709635754b7 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 7 Feb 2021 13:46:01 -0500 Subject: [PATCH 07/97] Reformat code, fix javadoc --- .../org/tensorflow/framework/metrics/AUC.java | 61 ++++++++++--------- .../framework/metrics/AUCSummationMethod.java | 8 +-- .../framework/metrics/Accuracy.java | 2 +- .../framework/metrics/BinaryAccuracy.java | 2 +- .../framework/metrics/FalseNegatives.java | 9 ++- .../framework/metrics/FalsePositives.java | 8 +-- .../tensorflow/framework/metrics/MeanIoU.java | 6 +- .../framework/metrics/MeanRelativeError.java | 15 +++-- .../framework/metrics/MeanTensor.java | 2 +- .../framework/metrics/Precision.java | 55 +++++++++-------- .../framework/metrics/PrecisionAtRecall.java | 32 +++++++--- .../tensorflow/framework/metrics/Recall.java | 46 +++++++------- .../framework/metrics/RecallAtPrecision.java | 39 ++++++++---- .../metrics/RootMeanSquaredError.java | 13 ++-- .../metrics/SensitivityAtSpecificity.java | 23 +++---- .../metrics/SparseCategoricalAccuracy.java | 20 +++--- .../metrics/SpecificityAtSensitivity.java | 25 ++++---- .../org/tensorflow/framework/metrics/Sum.java | 2 - .../metrics/TopKCategoricalAccuracy.java | 12 ++-- .../framework/metrics/TrueNegatives.java | 6 +- .../framework/metrics/TruePositives.java | 7 +-- .../metrics/impl/ConfusionMatrixEnum.java | 13 +++- .../framework/metrics/impl/MetricsHelper.java | 10 +-- .../impl/SensitivitySpecificityBase.java | 19 +++--- .../metrics/impl/WeightsBroadcastOps.java | 2 +- 25 files changed, 239 insertions(+), 198 deletions(-) diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 62311c3cda5..da89167e1f3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -34,27 +34,29 @@ /** * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. * - *

This metric creates four local variables, truePositives`, trueNegatives`, - * falsePositives` and falseNegatives` that are used to compute the AUC. To discretize the AUC - * curve, a linearly spaced set of thresholds is used to compute pairs of recall and precision - * values. The area under the ROC-curve is therefore computed using the height of the recall values - * by the false positive rate, while the area under the PR-curve is the computed using the height of - * the precision values by the recall. + *

This metric creates four local variables, truePositives, trueNegatives + * , falsePositives and falseNegatives that are used to compute the + * AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of + * recall and precision values. The area under the ROC-curve is therefore computed using the height + * of the recall values by the false positive rate, while the area under the PR-curve is the + * computed using the height of the precision values by the recall. * - *

This value is ultimately returned as auc, an idempotent operation that computes the area - * under a discretized curve of precision versus recall values (computed using the aforementioned - * variables). The numThresholds variable controls the degree of discretization with larger - * numbers of thresholds more closely approximating the true AUC. The quality of the approximation - * may vary dramatically depending on numThresholds`. The thresholds parameter can be used to - * manually specify thresholds which split the predictions more evenly. + *

This value is ultimately returned as auc, an idempotent operation that computes + * the area under a discretized curve of precision versus recall values (computed using the + * aforementioned variables). The numThresholds variable controls the degree of + * discretization with larger numbers of thresholds more closely approximating the true AUC. The + * quality of the approximation may vary dramatically depending on numThresholds. The + * thresholds parameter can be used to manually specify thresholds which split the + * predictions more evenly. + * + *

For best results, predictions should be distributed approximately uniformly in + * the range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor + * if this is not the case. Setting summationMethod to minoring or + * majoring can help quantify the error in the approximation by providing lower or upper + * bound estimate of the AUC. + * + *

Usage:
* - *

For best results, predictions should be distributed approximately uniformly in the range [0, - * 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor if this is not - * the case. Setting summationMethod to minoring or majoring can help quantify the error in - * the approximation by providing lower or upper bound estimate of the AUC. - *

- *

- * Usage:
*

  * AUC m = new  getTF().keras.metrics.AUC( getTF(), 3);
  * m.updateState( getTF().constant(new float[] {0, 0, 1,1}),
@@ -64,10 +66,11 @@
  * // tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
  * // recall = [1, 0.5, 0], fpRate = [1, 0, 0]
  * // auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75
- * Operand<TFloat32> result = m.result();
+ * Operand<TFloat32> result = m.result();
  * System.out.println(result.data().getFloat());
  * 0.75
  * 
+ * *
  * m.resetStates()
  * m.updateState( getTF().constant(new float[] {0, 0, 1, 1}),
@@ -170,7 +173,7 @@ public AUC(Ops tf, String name, long seed, Class type) {
    *
    * @param tf The TensorFlow Ops
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param seed the seed for random number generation. An initializer created with a given seed
    *     will always produce the same random tensor for a given shape and data type.
    * @param type the data type for the confusion matrix variables.
@@ -224,7 +227,7 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) {
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME}
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param seed the seed for random number generation. An initializer created with a given seed
    *     will always produce the same random tensor for a given shape and data type.
    * @param type the data type for the confusion matrix variables.
@@ -279,7 +282,7 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) {
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME}
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param seed the seed for random number generation. An initializer created with a given seed
@@ -336,7 +339,7 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C
    *
    * @param tf The TensorFlow Ops
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param seed the seed for random number generation. An initializer created with a given seed
@@ -392,7 +395,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type)
    *
    * @param tf The TensorFlow Ops
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param summationMethod Specifies the Riemann summation method used
@@ -442,7 +445,7 @@ public AUC(
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME}
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param summationMethod Specifies the Riemann summation method used,
@@ -462,7 +465,7 @@ public AUC(
   }
 
   /**
-   * Creates an AUC (Area under the curve) metric. using null> for the numThresholds,
+   * Creates an AUC (Area under the curve) metric. using null for the numThresholds,
    * false for multiLabel, and null for labelWeights.
    *
    * @param tf The TensorFlow Ops
@@ -493,7 +496,7 @@ public AUC(
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}.
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS}
+   *     must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS}
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param summationMethod Specifies the Riemann summation method used
@@ -577,7 +580,7 @@ public AUC(
                       .greaterEqual(
                           labelWeights, cast(getTF(), getTF().constant(0), labelWeights.type())),
                   Collections.singletonList(
-                      getTF().constant("All values of `labelWeights` must be non-negative.")));
+                      getTF().constant("All values of labelWeights must be non-negative.")));
 
       Ops ltf =
           getTF()
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java
index 09581c726d3..60687dd9005 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java
@@ -17,11 +17,11 @@
 /**
  * Specifies the Riemann summation method used. {@link #INTERPOLATION} (default) applies mid-point
  * summation scheme for ROC. For PR-AUC, interpolates (true/false) positives but not the ratio that
- * is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left summation
- * for increasing intervals and right summation for decreasing intervals; {@link #MAJORING} does the
- * opposite.
+ * is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left
+ * summation for increasing intervals and right summation for decreasing intervals; {@link
+ * #MAJORING} does the opposite.
  *
- * @see Davis & Goadrich. 2006
+ * @see Davis & Goadrich. 2006
  * @see Riemann summation method
  */
 public enum AUCSummationMethod {
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java
index f69170e57b9..9548fb42c65 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java
@@ -32,7 +32,7 @@
  * ultimately returned as binary accuracy: an idempotent operation that simply divides total by
  * count.
  *
- * 

If sampleWeights is null>, weights default to 1. Use sampleWeights of 0 to mask + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask * values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index 9e7f0f874cc..d2a414fdeb7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -30,7 +30,7 @@ * ultimately returned as binary accuracy: an idempotent operation that simply divides total by * count. * - *

If sampleWeights is null>, weights default to 1. Use sampleWeights of 0 to mask + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask * values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java index cf6f84af512..39d33dda665 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -26,13 +26,12 @@ * This metric creates one local variable, accumulator that is used to keep track of * the number of false negatives. * - *

If sampleWeightsnull - * sampleWeightsIf sampleWeights is null, weights default to 1. Use + * sampleWeights of 0 to mask values. + * * @param The data type for the metric result */ -public class FalseNegatives - extends ConfusionMatrixConditionCount { +public class FalseNegatives extends ConfusionMatrixConditionCount { /** * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name and a diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java index 629caaafb52..3cf9fc0a5e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -26,14 +26,12 @@ * This metric creates one local variable, accumulator that is used to keep track of * the number of false positives. * - *

If sampleWeightsnull - * sampleWeightsIf sampleWeights is null, weights default to 1. Use + * sampleWeights of 0 to mask values. * * @param The data type for the metric result */ -public class FalsePositives< T extends TNumber> - extends ConfusionMatrixConditionCount { +public class FalsePositives extends ConfusionMatrixConditionCount { /** * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name and a diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index c8205565802..19b13ed391c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -39,8 +39,8 @@ * / (true_positive + false_positive + false_negative). The predictions are accumulated in a * confusion matrix, weighted by sample_weight and the metric is then calculated from it. * - *

If sampleWeight is null, weights default to 1. Use sample_weight of 0 to mask - * values. + *

If sampleWeight is null, weights default to 1. Use sample_weight of + * 0 to mask values. * * @param The data type for the metric result */ @@ -72,7 +72,7 @@ protected MeanIoU(Ops tf, long numClasses, long seed, Class type) { } /** - * create a metric with reduction = AUTO + * Creates a MeanIoU metric * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index eb8ccaf76d2..4c48c0f88a7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -42,17 +42,20 @@ public class MeanRelativeError extends Mean { private Operand normalizer; /** - * create a metric with name = class name and reduction = AUTO + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ protected MeanRelativeError(Ops tf, float[] normalizer, long seed, Class type) { this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); } /** - * create a metric with reduction = AUTO + * Creates a MeanRelativeError metric * * @param tf the TensorFlow Ops * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. 
@@ -66,7 +69,7 @@ protected MeanRelativeError(Ops tf, String name, float[] normalizer, long seed, } /** - * Creates a MeanRelativeError metric with name = class name and reduction = AUTO + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. @@ -79,7 +82,7 @@ protected MeanRelativeError(Ops tf, double[] normalizer, long seed, Class typ } /** - * create a metric with reduction = AUTO + * Creates a MeanRelativeError metric * * @param tf the TensorFlow Ops * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. @@ -93,7 +96,7 @@ protected MeanRelativeError(Ops tf, String name, double[] normalizer, long seed, } /** - * create a metric with name = class name and reduction = AUTO + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. @@ -106,7 +109,7 @@ protected MeanRelativeError(Ops tf, Operand normalizer, long seed, Class t } /** - * create a metric + * Creates a MeanRelativeError metric * * @param tf the TensorFlow ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index d9c767965a6..3d6d8194aac 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -103,7 +103,7 @@ private boolean init(Shape shape) { } } - /** {@inheritDoc */ + /** {@inheritDoc} */ @Override public List updateStateList( Operand values, Operand sampleWeights) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index 6b70c6680cb..ee87cebfa48 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -31,19 +31,22 @@ /** * Computes the precision of the predictions with respect to the labels. * - *

The metric creates two local variables, truePositives and falsePositives that are used to - * compute the precision. This value is ultimately returned as precision, an idempotent operation - * that simply divides truePositives by the sum of truePositives and falsePositives. + *

The metric creates two local variables, truePositives and falsePositives + * that are used to compute the precision. This value is ultimately returned as precision, + * an idempotent operation that simply divides truePositives by the sum of + * truePositives and falsePositives. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values. + *

If sampleWeights is null, weights default to 1. Use sampleWeights of + * 0 to mask values. * - *

If is set, the metric calculates precision as how often on average a class among the top-k - * classes with the highest predicted values of a batch entry is correct and can be found in the - * label for that entry. + *

If topK is set, the metric calculates precision as how often on average a class + * among the top-k classes with the highest predicted values of a batch entry is correct and can be + * found in the label for that entry. * - *

If classId is specified, the metric calculates precision by considering only the entries in the batch - * for which classId is above the thresholds and/or in the top-k highest predictions, and computing - * the fraction of them for which classId is indeed a correct label. + *

If classId is specified, the metric calculates precision by considering only the + * entries in the batch for which classId is above the thresholds and/or + * in the top-k highest predictions, and computing the fraction of them for which classId + * is indeed a correct label. * * @param The data type for the metric result */ @@ -58,13 +61,13 @@ public class Precision extends Metric { private final String truePositivesName; private final String falsePositivesName; private final Class type; + private final List initializers = new ArrayList<>(); private Variable truePositives; private Variable falsePositives; - private final List initializers = new ArrayList<>(); /** * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId - * values and with a threshold of {@link #DEFAULT_THRESHOLD).} + * values and with a threshold of {@link #DEFAULT_THRESHOLD}. * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed @@ -77,7 +80,7 @@ public Precision(Ops tf, long seed, Class type) { /** * Creates a Precision Metric with no topK or classId values with a threshold of {@link - * #DEFAULT_THRESHOLD).} + * #DEFAULT_THRESHOLD}. * * @param tf the TensorFlow Ops * @param name name of the metric instance. 
If null, name defaults to {@link @@ -276,11 +279,8 @@ private void init() { Operand zero = zeros.call(tf.constant(Shape.of(thresholds.length)), type); if (this.truePositives == null) { - this.truePositives = - tf.withName(truePositivesName) - .variable(zero); + this.truePositives = tf.withName(truePositivesName).variable(zero); initializers.add(tf.assign(truePositives, zero)); - } if (this.falsePositives == null) { this.falsePositives = @@ -293,8 +293,10 @@ private void init() { /** {@inheritDoc} */ @Override @SuppressWarnings("unchecked") - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { Map> confusionMatrix = new HashMap<>(); confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives); @@ -314,7 +316,7 @@ public List updateStateList( thresholds, topK, classId, - tSampleWeights, + tSampleWeights, false, null)); } @@ -323,8 +325,7 @@ public List updateStateList( @Override public Operand result() { Ops tf = getTF(); - Operand result = - tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + Operand result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); return thresholds.length == 1 ? 
tf.slice( result, @@ -351,7 +352,7 @@ public float[] getThresholds() { /** * Gets the topK value, may be null * - * @return the topK + * @return the topK value or null */ public Integer getTopK() { return topK; @@ -360,7 +361,7 @@ public Integer getTopK() { /** * Gets the classId, may be null * - * @return the classId + * @return the classId or null */ public Integer getClassId() { return classId; @@ -375,7 +376,11 @@ public Variable getTruePositives() { return truePositives; } - /** Gets the falsePositives variable return the falsePositives */ + /** + * Gets the falsePositives variable + * + * @return the falsePositives + */ public Variable getFalsePositives() { return falsePositives; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 2ec66df0ca9..299c649279f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -24,10 +24,17 @@ /** * Computes best precision where recall is >= specified value. + * + *

This metric creates four local variables, truePositives, trueNegatives, falsePositives and + * falseNegatives that are used to compute the precision at the given recall. The threshold for the + * given recall value is computed and used to evaluate the corresponding precision. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of + * 0 to mask values. + * * @param The data type for the metric result */ -public class PrecisionAtRecall - extends SensitivitySpecificityBase { +public class PrecisionAtRecall extends SensitivitySpecificityBase { private final float recall; @@ -40,7 +47,8 @@ public class PrecisionAtRecall * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { this(tf, null, recall, DEFAULT_NUM_THRESHOLDS, seed, type); @@ -56,7 +64,8 @@ public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class type) { this(tf, name, recall, DEFAULT_NUM_THRESHOLDS, seed, type); @@ -72,7 +81,8 @@ public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. 
+ * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Class type) { this(tf, null, recall, numThresholds, seed, type); @@ -89,7 +99,8 @@ public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Cla * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public PrecisionAtRecall( Ops tf, String name, float recall, int numThresholds, long seed, Class type) { @@ -104,8 +115,7 @@ public Operand result() { Ops tf = getTF(); Operand recall = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); Operand sub = tf.math.sub(recall, cast(tf, tf.constant(value), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); @@ -115,7 +125,11 @@ public Operand result() { return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); } - /** @return the recall */ + /** + * Gets the recall value + * + * @return the recall value + */ public float getRecall() { return recall; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 0672b78f229..e1eebb98f77 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -30,22 +30,25 @@ /** * Computes the recall of the predictions with respect to the labels. - *

This metric creates two local - * variables, truePositives and falseNegatives, that are used to compute the recall. This value is - * ultimately returned as recall, an idempotent operation that simply divides truePositives by the sum of truePositives and falseNegatives. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values. + *

This metric creates two local variables, truePositives and falseNegatives + * , that are used to compute the recall. This value is ultimately returned as recall, an + * idempotent operation that simply divides truePositives by the sum of + * truePositives and falseNegatives. * - *

If is set, the metric calculates recall as how often on average a class among the labels of a - * batch entry is in the top-k predictions. + *

If sampleWeights is null, weights default to 1. Use sampleWeights of + * 0 to mask values. * - *

If classId is specified, the metric calculates recall by considering only the entries in the batch - * for which classId is in the label, and computing the fraction of them for which classId is above - * the threshold and/or in the top-k predictions. + *

If topK is set, the metric calculates recall as how often on average a class + * among the labels of a batch entry is in the top-k predictions. + * + *

If classId is specified, the metric calculates recall by considering only the + * entries in the batch for which classId is in the label, and computing the fraction + * of them for which classId is above the threshold and/or in the top-k predictions. * * @param The data type for the metric result */ -public class Recall< T extends TNumber> extends Metric< T> { +public class Recall extends Metric { public static final float DEFAULT_THRESHOLD = 0.5f; public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; @@ -56,9 +59,9 @@ public class Recall< T extends TNumber> extends Metric< T> { private final String truePositivesName; private final String falseNegativesName; private final Class type; + private final List initializers = new ArrayList<>(); private Variable truePositives; private Variable falseNegatives; - private final List initializers = new ArrayList<>(); /** * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set @@ -301,17 +304,13 @@ private void init() { Operand zero = zeros.call(tf.constant(Shape.of(this.thresholds.length)), type); if (truePositives == null) { - truePositives = - tf.withName(truePositivesName) - .variable(zero); + truePositives = tf.withName(truePositivesName).variable(zero); initializers.add(tf.assign(truePositives, zero)); } - + if (this.falseNegatives == null) { - falseNegatives = - tf.withName(falseNegativesName) - .variable(zero); + falseNegatives = tf.withName(falseNegativesName).variable(zero); initializers.add(tf.assign(falseNegatives, zero)); } } @@ -326,7 +325,9 @@ public Op resetStates() { @Override @SuppressWarnings("unchecked") public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, + Operand predictions, + Operand sampleWeights) { Map> confusionMatrix = new HashMap<>(); confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives); @@ -345,17 +346,16 
@@ public List updateStateList( this.thresholds, this.topK, this.classId, - tSampleWeights, + tSampleWeights, false, null); } @Override public Operand result() { - Ops tf = getTF(); + Ops tf = getTF(); Operand result = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); return this.thresholds.length == 1 ? tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1])) : result; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index 6c774f0c765..fb6890d1e01 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -24,8 +24,22 @@ import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; import static org.tensorflow.framework.utils.CastHelper.cast; -public class RecallAtPrecision - extends SensitivitySpecificityBase { +/** + * Computes best recall where precision is >= specified value. + * + *

For a given score-label-distribution the required precision might not be achievable, in this + * case 0.0 is returned as recall. + * + *

This metric creates four local variables, truePositives, trueNegatives, falsePositives and + * falseNegatives that are used to compute the recall at the given precision. The threshold for the + * given precision value is computed and used to evaluate the corresponding recall. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of + * 0 to mask values. + * + * @param The data type for the metric result + */ +public class RecallAtPrecision extends SensitivitySpecificityBase { private final float precision; @@ -38,7 +52,8 @@ public class RecallAtPrecision * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { this(tf, null, precision, DEFAULT_NUM_THRESHOLDS, seed, type); @@ -54,7 +69,8 @@ public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class type) { this(tf, name, precision, DEFAULT_NUM_THRESHOLDS, seed, type); @@ -70,7 +86,8 @@ public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class< * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. 
+ * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, Class type) { this(tf, null, precision, numThresholds, seed, type); @@ -87,7 +104,8 @@ public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public RecallAtPrecision( Ops tf, String name, float precision, int numThresholds, long seed, Class type) { @@ -103,18 +121,15 @@ public Operand result() { Ops tf = getTF(); Operand precisions = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falsePositives)); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falsePositives)); Operand recalls = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); Operand isFeasible = tf.math.greaterEqual(precisions, cast(tf, tf.constant(this.value), getType())); Where feasible = tf.where(isFeasible); Operand feasibleExists = tf.math.greater(tf.size(feasible), tf.constant(0)); - Operand gather = - tf.expandDims(tf.gather(recalls, feasible, tf.constant(0)), tf.constant(0)); + Operand gather = tf.expandDims(tf.gather(recalls, feasible, tf.constant(0)), tf.constant(0)); return tf.select( feasibleExists, tf.reduceMax(gather, allAxes(tf, gather)), diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 2133642564b..9b4401964d7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -27,12 +27,12 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes root mean squared error metric between labels> and predictions + * Computes root mean squared error metric between labels and predictions * . * * @param The data type for the metric result */ -public class RootMeanSquaredError< T extends TNumber> extends Mean< T> { +public class RootMeanSquaredError extends Mean { /** * Creates a RootMeanSquaredError metric with a name of {@link Class#getSimpleName()} @@ -62,12 +62,15 @@ public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); - Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + Operand tSampleWeights = + sampleWeights != null ? 
cast(getTF(), sampleWeights, getResultType()) : null; LossTuple ops = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); tPredictions = ops.getTarget(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 7cf694868e6..2c7420a5518 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -23,7 +23,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes best sensitivity where sensitivity is >= specified value. + * Computes best sensitivity where sensitivity is >= specified value. * *

Sensitivity measures the proportion of actual positives that are correctly * identified as such (tp / (tp + fn)). @@ -36,15 +36,14 @@ * sensitivity at the given specificity. The threshold for the given specificity value is computed * and used to evaluate the corresponding sensitivity. * - *

If sampleWeights is null>, weights default to 1. Use sample_weight - * of 0 to mask values. + *

If sampleWeights is null, weights default to 1. Use sample_weight of + * 0 to mask values. * * @see Additional information * about specificity and sensitivity * @param The data type for the metric result */ -public class SensitivityAtSpecificity - extends SensitivitySpecificityBase { +public class SensitivityAtSpecificity extends SensitivitySpecificityBase { private final float specificity; @@ -57,7 +56,7 @@ public class SensitivityAtSpecificity * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class type) { @@ -74,7 +73,7 @@ public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class t * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ public SensitivityAtSpecificity( @@ -92,7 +91,7 @@ public SensitivityAtSpecificity( * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. 
*/ public SensitivityAtSpecificity( @@ -111,7 +110,7 @@ public SensitivityAtSpecificity( * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ public SensitivityAtSpecificity( @@ -127,10 +126,8 @@ public SensitivityAtSpecificity( public Operand result() { Ops tf = getTF(); Operand specificities = - tf.math.divNoNan( - this.trueNegatives, tf.math.add(this.trueNegatives, this.falsePositives)); - Operand sub = - tf.math.sub(specificities, cast(tf, tf.constant(this.getValue()), getType())); + tf.math.divNoNan(this.trueNegatives, tf.math.add(this.trueNegatives, this.falsePositives)); + Operand sub = tf.math.sub(specificities, cast(tf, tf.constant(this.getValue()), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 156a4995b02..7034861d8d2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -31,24 +31,23 @@ /** * Calculates how often predictions matches integer labels. * - *

You can provide logits of classes as predictions, since argmax of logits and probabilities are - * same. + *

You can provide logits of classes as predictions, since argmax of logits and + * probabilities are same. * *

This metric creates two local variables, `total` and `count` that are used to compute the - * frequency with which predictions matches labels. This frequency is ultimately returned as `sparse - * categorical accuracy`: an idempotent operation that simply divides `total` by `count`. + * frequency with which predictions matches labels. This frequency is + * ultimately returned as `sparse categorical accuracy`: an idempotent operation that simply divides + * `total` by `count`. * *

If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.' * *

Usage: * - *

- * *

  * SparseCategoricalAccuracy m = new tf.keras.metrics.SparseCategoricalAccuracy();
  * m.update_state(tf.constant(new float[][] {{2}, {1}}),
  *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}));
- * Operand<TFloat32>> result = m.result();
+ * Operand<TFloat32> result = m.result();
  * System.out.println(result.data().getFloat());
  * 0.5
  * 
@@ -87,7 +86,7 @@ public class SparseCategoricalAccuracy extends MeanMetricWrap * will always produce the same random tensor for a given shape and data type. * @param type The data type for the metric result */ - public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { + public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { this(tf, null, seed, type); } @@ -100,7 +99,7 @@ public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the type of the metric result. */ - public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); super.setLoss(this); } @@ -108,8 +107,7 @@ public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) /** {@inheritDoc} */ @Override public Operand call( - Operand labels, - Operand predictions) { + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index 59f6f44c1f2..d0b797690bd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -23,7 +23,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes best specificity where sensitivity is >= specified value. Sensitivity + * Computes best specificity where sensitivity is >= specified value. Sensitivity * measures the proportion of actual positives that are correctly identified as such * (tp / (tp + fn)). 
* @@ -35,15 +35,14 @@ * specificity at the given sensitivity. The threshold for the given sensitivity value is computed * and used to evaluate the corresponding specificity. * - *

If sampleWeights is null>, weights default to 1. Use sample_weight - * of 0 to mask values. + *

If sampleWeights is null, weights default to 1. Use sample_weight of + * 0 to mask values. * * @see Additional information * about specificity and sensitivity * @param The data type for the metric result */ -public class SpecificityAtSensitivity - extends SensitivitySpecificityBase { +public class SpecificityAtSensitivity extends SensitivitySpecificityBase { private final float sensitivity; @@ -56,7 +55,7 @@ public class SpecificityAtSensitivity * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. */ public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class type) { @@ -73,7 +72,7 @@ public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class t * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. */ public SpecificityAtSensitivity( @@ -91,7 +90,7 @@ public SpecificityAtSensitivity( * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. 
*/ public SpecificityAtSensitivity( @@ -110,7 +109,7 @@ public SpecificityAtSensitivity( * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. */ public SpecificityAtSensitivity( @@ -124,14 +123,12 @@ public SpecificityAtSensitivity( /** {@inheritDoc} */ @Override public Operand result() { - + Ops tf = getTF(); Operand sensitivities = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); - Operand sub = - tf.math.sub(sensitivities, cast(tf, tf.constant(this.getValue()), getType())); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + Operand sub = tf.math.sub(sensitivities, cast(tf, tf.constant(this.getValue()), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java index 4312d7a97f0..a3241221b66 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -28,8 +28,6 @@ * values. This is ultimately returned as sum. * *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. - * - */ public class Sum extends Reduce { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index d2db4f368ac..ad78e48bc34 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -22,12 +22,13 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -/** Computes the poisson loss metric between labels and predictions. +/** + * Computes the poisson loss metric between labels and predictions. * * @param The data type for the metric result */ -public class TopKCategoricalAccuracy - extends MeanMetricWrapper implements LossMetric { +public class TopKCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { public static final int DEFAULT_K = 5; /** Number of top elements to look at for computing accuracy. */ private final int k; @@ -40,6 +41,7 @@ public class TopKCategoricalAccuracy * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type The data type for the metric result */ public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { this(tf, name, DEFAULT_K, seed, type); @@ -53,6 +55,7 @@ public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { * @param k Number of top elements to look at for computing accuracy. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type The data type for the metric result */ public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { super(tf, name, seed, type); @@ -62,7 +65,8 @@ public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class t /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Metrics.topKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java index de6428fed88..91b6751588a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -26,14 +26,12 @@ * This metric creates one local variable, accumulator that is used to keep track of * the number of true negatives. * - *

If sampleWeightsnull, weights - * default to 1. Use + *

If sampleWeights is null, weights default to 1. Use * sampleWeights of 0 to mask values. * * @param The data type for the metric result */ -public class TrueNegatives - extends ConfusionMatrixConditionCount { +public class TrueNegatives extends ConfusionMatrixConditionCount { /** * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name and a diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java index c573b6b5719..b67d381a62d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -26,13 +26,12 @@ * This metric creates one local variable, accumulator that is used to keep track of * the number of true positives. * - *

If sampleWeightsnull, weights - * default to 1. Use + *

If sampleWeights is null, weights default to 1. Use * sampleWeights of 0 to mask values. + * * @param The data type for the metric result */ -public class TruePositives - extends ConfusionMatrixConditionCount< T> { +public class TruePositives extends ConfusionMatrixConditionCount { /** * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name and a diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java index b76356661a9..281aa2072d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -27,7 +27,12 @@ public enum ConfusionMatrixEnum { private final String abbrev; - /** Creates a ConfusionMatrixEnum */ + /** + * Creates a ConfusionMatrixEnum + * + * @param abbrev the abbreviation for the confusion condition as required by the underlying + * TensorFlow api. + */ ConfusionMatrixEnum(String abbrev) { this.abbrev = abbrev; } @@ -50,7 +55,11 @@ public static ConfusionMatrixEnum get(String item) { return null; } - /** Gets the abbreviation for this enum value */ + /** + * Gets the abbreviation for this enum value + * + * @return the abbreviation for this enum value as required by the underlying TensorFlow api. 
+ */ public String getAbbreviation() { return abbrev; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index cbb24933967..0be0a7a572a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -283,10 +283,10 @@ public static List assertShapes( *

For every pair of values in labels and predictions: * *

-   * TRUE_POSITIVES:  labels == true and predictions > thresholds
-   * FALSE_POSITIVES: labels == true and predictions <= thresholds
-   * TRUE_NEGATIVES:  labels == false and predictions <= thresholds
-   * FALSE_NEGATIVE:  labels == false and predictions > thresholds
+   * TRUE_POSITIVES:  labels == true and predictions > thresholds
+   * FALSE_POSITIVES: labels == true and predictions <= thresholds
+   * TRUE_NEGATIVES:  labels == false and predictions <= thresholds
+   * FALSE_NEGATIVE:  labels == false and predictions > thresholds
    * 
* *

The results will be weighted and added together. When multiple thresholds are provided, we @@ -324,7 +324,7 @@ public static List assertShapes( * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. * @param the data type for the variables * @throws IllegalArgumentException If predictions and labels have - * mismatched shapes, or if sampleWeight is not null>and its shape + * mismatched shapes, or if sampleWeight is not nulland its shape * doesn't match predictions * @return an op to update the given confusion matrix variables. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 3949ede822a..377124333bd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -55,7 +55,7 @@ public abstract class SensitivitySpecificityBase extends Metr * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0. + * @throws IllegalArgumentException if numThresholds <= 0. 
*/ protected SensitivitySpecificityBase( Ops tf, String name, float value, int numThresholds, long seed, Class type) { @@ -114,28 +114,29 @@ private void init() { public Op initializeVariables() { List varInitializers = new ArrayList<>(); - if(truePositivesInitializer != null ) { + if (truePositivesInitializer != null) { varInitializers.add(truePositivesInitializer); } - if(falsePositivesInitializer != null ) { + if (falsePositivesInitializer != null) { varInitializers.add(falsePositivesInitializer); } - if(trueNegativesInitializer != null ) { + if (trueNegativesInitializer != null) { varInitializers.add(trueNegativesInitializer); } - if(falseNegativesInitializer != null ) { + if (falseNegativesInitializer != null) { varInitializers.add(falseNegativesInitializer); } return getTF().withControlDependencies(varInitializers).noOp(); - } /** {@inheritDoc} */ @Override @SuppressWarnings("unchecked") - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { Ops tf = getTF(); Operand tLabels = cast(tf, labels, type); Operand tPredictions = cast(tf, predictions, type); @@ -156,7 +157,7 @@ public List updateStateList( this.getThresholds(), null, null, - tSampleWeights, + tSampleWeights, false, null); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 09752798ad5..36792b8ea7a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -153,7 +153,7 @@ private static Operand hasValidDims( * *

This returns a version of weights following the same broadcast rules as * mul(weights, - * values), but limited to the weights shapes allowed by assertBroadcastable + * values), but limited to the weights shapes allowed by assertBroadcastable * When computing a weighted average, use this function to broadcast weights before * summing them; e.g., reduceSum(w * v) / reduceSum(_broadcast_weights(w, v)). * From 2c938840e9b526f0100cc188c6ad200328137ad4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 2 Mar 2021 10:45:02 -0500 Subject: [PATCH 08/97] Change thresholds to Operand --- .../org/tensorflow/framework/metrics/AUC.java | 309 ++++++++---------- .../framework/metrics/Precision.java | 16 +- .../tensorflow/framework/metrics/Recall.java | 16 +- .../impl/ConfusionMatrixConditionCount.java | 9 +- .../framework/metrics/impl/MetricsHelper.java | 113 ++++--- .../impl/SensitivitySpecificityBase.java | 10 +- .../framework/metrics/PrecisionTest.java | 2 +- 7 files changed, 213 insertions(+), 262 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index da89167e1f3..8a31dfd3fce 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -58,9 +58,9 @@ *

Usage:
* *

- * AUC m = new  getTF().keras.metrics.AUC( getTF(), 3);
- * m.updateState( getTF().constant(new float[] {0, 0, 1,1}),
- *          getTF().constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
+ * AUC m = new AUC(tf, 3);
+ * m.updateState(tf.constant(new float[] {0, 0, 1, 1}),
+ *          tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
  *
  * // threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
  * // tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
@@ -73,9 +73,9 @@
  *
  * 
  * m.resetStates()
- * m.updateState( getTF().constant(new float[] {0, 0, 1, 1}),
- *                 getTF().constant(new float[] {0f, 0.5f, 0.3f, 0.9f}, ),
- *                 getTF().constant(new float[] {1, 0, 0, 1}));
+ * m.updateState(tf.constant(new float[] {0, 0, 1, 1}),
+ *                 tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}),
+ *                 tf.constant(new float[] {1, 0, 0, 1}));
  * result = m.result();
  * System.out.println(result.data().getFloat());
  * 1.0
@@ -209,7 +209,7 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) {
     this(
         tf,
         null,
-        null,
+        DEFAULT_NUM_THRESHOLDS,
         AUCCurve.ROC,
         AUCSummationMethod.INTERPOLATION,
         thresholds,
@@ -264,7 +264,7 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) {
     this(
         tf,
         name,
-        null,
+            DEFAULT_NUM_THRESHOLDS,
         AUCCurve.ROC,
         AUCSummationMethod.INTERPOLATION,
         thresholds,
@@ -322,7 +322,7 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C
     this(
         tf,
         name,
-        null,
+            DEFAULT_NUM_THRESHOLDS,
         curve,
         AUCSummationMethod.INTERPOLATION,
         thresholds,
@@ -378,7 +378,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type)
     this(
         tf,
         null,
-        null,
+            DEFAULT_NUM_THRESHOLDS,
         curve,
         AUCSummationMethod.INTERPOLATION,
         thresholds,
@@ -435,7 +435,7 @@ public AUC(
       AUCSummationMethod summationMethod,
       long seed,
       Class type) {
-    this(tf, null, null, curve, summationMethod, thresholds, false, null, seed, type);
+    this(tf, null, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type);
   }
 
   /**
@@ -487,7 +487,7 @@ public AUC(
       AUCSummationMethod summationMethod,
       long seed,
       Class type) {
-    this(tf, name, null, curve, summationMethod, thresholds, false, null, seed, type);
+    this(tf, name, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type);
   }
 
   /**
@@ -496,7 +496,7 @@ public AUC(
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}.
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS}
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param summationMethod Specifies the Riemann summation method used
@@ -520,7 +520,7 @@ public AUC(
   public AUC(
       Ops tf,
       String name,
-      Integer numThresholds,
+      int numThresholds,
       AUCCurve curve,
       AUCSummationMethod summationMethod,
       float[] thresholds,
@@ -529,10 +529,10 @@ public AUC(
       long seed,
       Class type) {
     super(tf, name == null ? DEFAULT_NAME : name, seed);
-    this.truePositivesName = this.getVariableName(TRUE_POSITIVES);
-    this.falsePositivesName = this.getVariableName(FALSE_POSITIVES);
-    this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES);
-    this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES);
+    truePositivesName = getVariableName(TRUE_POSITIVES);
+    falsePositivesName = getVariableName(FALSE_POSITIVES);
+    trueNegativesName = getVariableName(TRUE_NEGATIVES);
+    falseNegativesName = getVariableName(FALSE_NEGATIVES);
     this.curve = curve;
     this.summationMethod = summationMethod;
     this.type = type;
@@ -540,18 +540,23 @@ public AUC(
     this.multiLabel = multiLabel;
 
     if (thresholds != null) { // ignore numThresholds
-      for (float t : thresholds)
-        if (t < 0.0f || t > 1.0f)
+      for (float t : thresholds) {
+        if (t < 0.0f || t > 1.0f) {
           throw new IllegalArgumentException(
               String.format(
                   "Threshold values must be in [0, 1]. Invalid values: %s",
                   Arrays.toString(thresholds)));
+        }
+      }
       this.numThresholds = thresholds.length + 2;
       Arrays.sort(thresholds);
     } else {
-      if (numThresholds <= 1) throw new IllegalArgumentException("numThresholds must be > 1.");
+
+      if (numThresholds <= 1) {
+        throw new IllegalArgumentException("numThresholds must be > 1.");
+      }
       this.numThresholds = numThresholds;
-      thresholds = new float[numThresholds - 2];
+      thresholds = new float[this.numThresholds - 2];
       // linearly interpolate (numThresholds - 2) thresholds between endpoints
       for (int i = 0; i < thresholds.length; i++) {
         thresholds[i] = (i + 1) * 1.0f / (this.numThresholds - 1);
@@ -559,39 +564,38 @@ public AUC(
     }
     // Add an endpoint "threshold" below zero and above one for either
     // threshold method to account for floating point imprecision.
-    if (thresholds.length != this.numThresholds - 2)
+    if (thresholds.length != this.numThresholds - 2) {
       throw new IllegalArgumentException(
           "Thresholds length must contain numThresholds - 2 entries");
+    }
+    // Add an endpoint "threshold" below zero and above one for either
+    // threshold method to account for floating point imprecisions.
     this.thresholds = new float[this.numThresholds];
     this.thresholds[0] = -EPSILON;
     System.arraycopy(thresholds, 0, this.thresholds, 1, thresholds.length);
     this.thresholds[this.numThresholds - 1] = 1 + EPSILON;
 
+    // Handle multilabel arguments.
+
     if (labelWeights != null) {
       // assert that labelWeights are non-negative.
 
       this.labelWeights = labelWeights;
       Op checks =
-          getTF()
-              .withSubScope("AUC")
+          tf.withSubScope("AUC")
               .assertThat(
-                  getTF()
-                      .math
-                      .greaterEqual(
-                          labelWeights, cast(getTF(), getTF().constant(0), labelWeights.type())),
+                  tf.math.greaterEqual(labelWeights, cast(tf, tf.constant(0), labelWeights.type())),
                   Collections.singletonList(
-                      getTF().constant("All values of labelWeights must be non-negative.")));
+                      tf.constant("All values of labelWeights must be non-negative.")));
 
       Ops ltf =
-          getTF()
-              .withSubScope("updateState")
-              .withControlDependencies(Collections.singletonList(checks));
+          tf.withSubScope("updateState").withControlDependencies(Collections.singletonList(checks));
 
       this.labelWeights = ltf.identity(this.labelWeights);
     }
 
-    if (this.multiLabel) {
-      this.numLabels = null;
+    if (multiLabel) {
+      numLabels = null;
     }
   }
 
@@ -607,6 +611,7 @@ private Map> build(Shape shape) {
     if (initialized) {
       return Collections.EMPTY_MAP;
     }
+    Ops tf = getTF();
 
     if (this.isMultiLabel()) {
       if (shape == null) {
@@ -623,26 +628,27 @@ private Map> build(Shape shape) {
       variableShape = Shape.of(this.numThresholds);
     }
 
+    // Create metric variables
     Zeros zeros = new Zeros<>(getTF());
-    Operand zero = zeros.call(getTF().constant(variableShape), type);
+    Operand zero = zeros.call(tf.constant(variableShape), type);
     if (truePositives == null) {
-      truePositives = getTF().withName(getTruePositivesName()).variable(zero);
-      initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTF().assign(truePositives, zero));
+      truePositives = tf.withName(getTruePositivesName()).variable(zero);
+      initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, tf.assign(truePositives, zero));
     }
 
     if (falsePositives == null) {
-      falsePositives = getTF().withName(getFalsePositivesName()).variable(zero);
-      initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, getTF().assign(falsePositives, zero));
+      falsePositives = tf.withName(getFalsePositivesName()).variable(zero);
+      initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, tf.assign(falsePositives, zero));
     }
 
     if (trueNegatives == null) {
-      trueNegatives = getTF().withName(getTrueNegativesName()).variable(zero);
-      initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, getTF().assign(trueNegatives, zero));
+      trueNegatives = tf.withName(getTrueNegativesName()).variable(zero);
+      initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, tf.assign(trueNegatives, zero));
     }
 
     if (falseNegatives == null) {
-      falseNegatives = getTF().withName(getFalseNegativesName()).variable(zero);
-      initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, getTF().assign(falseNegatives, zero));
+      falseNegatives = tf.withName(getFalseNegativesName()).variable(zero);
+      initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, tf.assign(falseNegatives, zero));
     }
 
     this.initialized = true;
@@ -656,19 +662,22 @@ public List updateStateList(
       Operand labels,
       Operand predictions,
       Operand sampleWeights) {
-
-    Operand lLabels = cast(getTF(), labels, type);
-    Operand lPredictions = cast(getTF(), predictions, type);
-    Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null;
+    Ops tf = getTF();
+    Operand tLabels = cast(tf, labels, type);
+    Operand tPredictions = cast(tf, predictions, type);
+    Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
     List updateOperations = new ArrayList<>();
     Map> varInitializers = Collections.EMPTY_MAP;
     if (!this.initialized) {
-      varInitializers = build(lPredictions.shape());
+      varInitializers = build(tPredictions.shape());
     }
     if (this.isMultiLabel() || this.getLabelWeights() != null) {
+      // labels should have shape (number of examples, number of labels).
       List> symbols = new ArrayList<>();
-      symbols.add(new SymbolicShape<>(lLabels, "N", "L"));
+      symbols.add(new SymbolicShape<>(tLabels, "N", "L"));
       if (this.isMultiLabel()) {
+        // TP, TN, FP, and FN should all have shape
+        // (number of thresholds, number of labels).
         symbols.add(new SymbolicShape<>(this.truePositives, "T", "L"));
         symbols.add(new SymbolicShape<>(this.falsePositives, "T", "L"));
         symbols.add(new SymbolicShape<>(this.trueNegatives, "T", "L"));
@@ -678,30 +687,34 @@ public List updateStateList(
         symbols.add(new SymbolicShape<>(this.getLabelWeights(), "L", ""));
       }
       updateOperations.addAll(
-          MetricsHelper.assertShapes(getTF(), symbols, "Number of labels is not consistent."));
-    }
-    if (this.isMultiLabel()) {
-      this.labelWeights = null;
+          MetricsHelper.assertShapes(tf, symbols, "Number of labels is not consistent."));
     }
+
     Map> confusionMatrix = new HashMap<>();
     confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives);
     confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.falsePositives);
     confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.trueNegatives);
     confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives);
 
+    // Only forward labelWeights to update_confusion_matrix_variables when
+    // multiLabel is false. Otherwise the averaging of individual label AUCs is
+    // handled in AUC.result
+    if (this.isMultiLabel()) {
+      this.labelWeights = null;
+    }
     updateOperations.addAll(
         MetricsHelper.updateConfusionMatrixVariables(
-            getTF(),
+            tf,
             confusionMatrix,
             varInitializers,
-            lLabels,
-            lPredictions,
-            this.thresholds,
+            tLabels,
+            tPredictions,
+            tf.constant(thresholds),
             null,
             null,
             tSampleWeights,
-            this.isMultiLabel(),
-            this.getLabelWeights()));
+            isMultiLabel(),
+            getLabelWeights()));
     return updateOperations;
   }
 
@@ -712,147 +725,84 @@ public List updateStateList(
    */
   private Operand interpolatePRAuc() {
     // truePositives[:self.numThresholds - 1]
+    Ops tf = getTF();
     Operand tp0 =
-        getTF()
-            .slice(
-                truePositives,
-                getTF().constant(new int[] {0}),
-                getTF().constant(new int[] {this.getNumThresholds() - 1}));
+        tf.slice(
+            truePositives,
+            tf.constant(new int[] {0}),
+            tf.constant(new int[] {this.getNumThresholds() - 1}));
     // truePositives[1:]
     Operand tp1 =
-        getTF()
-            .slice(
-                truePositives, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1}));
+        tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
 
-    Operand dTP = getTF().math.sub(tp0, tp1);
+    Operand dTP = tf.math.sub(tp0, tp1);
 
-    Operand p = getTF().math.add(truePositives, falsePositives);
+    Operand p = tf.math.add(truePositives, falsePositives);
 
     Operand dP =
-        getTF()
-            .math
-            .sub(
-                getTF()
-                    .slice(
-                        p,
-                        getTF().constant(new int[] {0}),
-                        getTF().constant(new int[] {this.getNumThresholds() - 1})),
-                getTF()
-                    .slice(p, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1})));
+        tf.math.sub(
+            tf.slice(
+                p,
+                tf.constant(new int[] {0}),
+                tf.constant(new int[] {this.getNumThresholds() - 1})),
+            tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})));
 
     Operand precisionSlope =
-        getTF()
-            .math
-            .divNoNan(
-                dTP, getTF().math.maximum(dP, getTF().dtypes.cast(getTF().constant(0), dP.type())));
+        tf.math.divNoNan(dTP, tf.math.maximum(dP, tf.dtypes.cast(tf.constant(0), dP.type())));
 
     Operand intercept =
-        getTF()
-            .math
-            .sub(
-                getTF()
-                    .slice(
-                        truePositives,
-                        getTF().constant(new int[] {1}),
-                        getTF().constant(new int[] {-1})),
-                getTF()
-                    .math
-                    .mul(
-                        precisionSlope,
-                        getTF()
-                            .slice(
-                                p,
-                                getTF().constant(new int[] {1}),
-                                getTF().constant(new int[] {-1}))));
+        tf.math.sub(
+            tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
+            tf.math.mul(
+                precisionSlope,
+                tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))));
 
     Operand safePRatio =
-        getTF()
-            .select(
-                getTF()
-                    .math
-                    .logicalAnd(
-                        getTF()
-                            .math
-                            .greater(
-                                getTF()
-                                    .slice(
-                                        p,
-                                        getTF().constant(new int[] {0}),
-                                        getTF().constant(new int[] {this.getNumThresholds() - 1})),
-                                getTF().dtypes.cast(getTF().constant(0), p.type())),
-                        getTF()
-                            .math
-                            .greater(
-                                getTF()
-                                    .slice(
-                                        p,
-                                        getTF().constant(new int[] {1}),
-                                        getTF().constant(new int[] {-1})),
-                                getTF().dtypes.cast(getTF().constant(0), p.type()))),
-                getTF()
-                    .math
-                    .divNoNan(
-                        getTF()
-                            .slice(
-                                p,
-                                getTF().constant(new int[] {0}),
-                                getTF().constant(new int[] {this.getNumThresholds() - 1})),
-                        getTF()
-                            .math
-                            .maximum(
-                                getTF()
-                                    .slice(
-                                        p,
-                                        getTF().constant(new int[] {1}),
-                                        getTF().constant(new int[] {-1})),
-                                getTF().dtypes.cast(getTF().constant(0), p.type()))),
-                getTF()
-                    .onesLike(
-                        getTF()
-                            .slice(
-                                p,
-                                getTF().constant(new int[] {1}),
-                                getTF().constant(new int[] {-1}))));
+        tf.select(
+            tf.math.logicalAnd(
+                tf.math.greater(
+                    tf.slice(
+                        p,
+                        tf.constant(new int[] {0}),
+                        tf.constant(new int[] {this.getNumThresholds() - 1})),
+                    tf.dtypes.cast(tf.constant(0), p.type())),
+                tf.math.greater(
+                    tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
+                    tf.dtypes.cast(tf.constant(0), p.type()))),
+            tf.math.divNoNan(
+                tf.slice(
+                    p,
+                    tf.constant(new int[] {0}),
+                    tf.constant(new int[] {this.getNumThresholds() - 1})),
+                tf.math.maximum(
+                    tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
+                    tf.dtypes.cast(tf.constant(0), p.type()))),
+            tf.onesLike(tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))));
 
     Operand fn1 =
-        getTF()
-            .slice(
-                falseNegatives, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1}));
+        tf.slice(falseNegatives, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
 
     Operand aucTotalPos =
-        getTF()
-            .math
-            .mul(
-                precisionSlope,
-                getTF().math.add(dTP, getTF().math.mul(intercept, getTF().math.log(safePRatio))));
+        tf.math.mul(
+            precisionSlope, tf.math.add(dTP, tf.math.mul(intercept, tf.math.log(safePRatio))));
 
     Operand prAucIncrement =
-        getTF()
-            .math
-            .divNoNan(
-                aucTotalPos,
-                getTF()
-                    .math
-                    .maximum(
-                        getTF().math.add(tp1, fn1),
-                        getTF().dtypes.cast(getTF().constant(0), this.truePositives.type())));
+        tf.math.divNoNan(
+            aucTotalPos,
+            tf.math.maximum(
+                tf.math.add(tp1, fn1), tf.dtypes.cast(tf.constant(0), this.truePositives.type())));
 
     if (this.isMultiLabel()) {
-      Operand byLabelAuc = getTF().reduceSum(prAucIncrement, getTF().constant(0));
+      Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0));
       if (this.getLabelWeights() == null) {
-        return MetricsHelper.mean(getTF(), byLabelAuc);
+        return MetricsHelper.mean(tf, byLabelAuc);
       } else {
-        return getTF()
-            .math
-            .divNoNan(
-                getTF()
-                    .reduceSum(
-                        getTF().math.mul(byLabelAuc, this.getLabelWeights()),
-                        allAxes(getTF(), byLabelAuc)),
-                getTF().reduceSum(getLabelWeights(), allAxes(getTF(), getLabelWeights())));
+        return tf.math.divNoNan(
+            tf.reduceSum(tf.math.mul(byLabelAuc, this.getLabelWeights()), allAxes(tf, byLabelAuc)),
+            tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights())));
       }
     } else {
-      return getTF().reduceSum(prAucIncrement, allAxes(getTF(), prAucIncrement));
+      return tf.reduceSum(prAucIncrement, allAxes(tf, prAucIncrement));
     }
   }
 
@@ -862,13 +812,13 @@ public Operand result() {
 
     if (this.getCurve() == AUCCurve.PR
         && this.getSummationMethod() == AUCSummationMethod.INTERPOLATION) {
+      // This use case is different and is handled separately.
       return this.interpolatePRAuc();
     }
     Ops tf = getTF();
     Operand x;
     Operand y;
-    Operand recall =
-        getTF().math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives));
+    Operand recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives));
 
     if (this.getCurve() == AUCCurve.ROC) {
       x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives));
@@ -890,7 +840,7 @@ public Operand result() {
     switch (this.getSummationMethod()) {
       case INTERPOLATION:
         heights =
-            tf.math.div(tf.math.add(ySlice1, ySlice2), tf.dtypes.cast(tf.constant(2), y.type()));
+            tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type()));
         break;
       case MINORING:
         heights = tf.math.minimum(ySlice1, ySlice2);
@@ -915,6 +865,7 @@ public Operand result() {
       if (this.getLabelWeights() == null) {
         return MetricsHelper.mean(tf, byLabelAuc);
       } else {
+        // Weighted average of the label AUCs.
         return tf.math.divNoNan(
             tf.reduceSum(
                 tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())),
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java
index ee87cebfa48..bd536f16b29 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java
@@ -75,7 +75,7 @@ public class Precision extends Metric {
    * @param type the data type for the variables
    */
   public Precision(Ops tf, long seed, Class type) {
-    this(tf, null, new float[] {DEFAULT_THRESHOLD}, null, null, seed, type);
+    this(tf, null, null, null, null, seed, type);
   }
 
   /**
@@ -90,7 +90,7 @@ public Precision(Ops tf, long seed, Class type) {
    * @param type the data type for the variables
    */
   public Precision(Ops tf, String name, long seed, Class type) {
-    this(tf, name, new float[] {DEFAULT_THRESHOLD}, null, null, seed, type);
+    this(tf, name, null, null, null, seed, type);
   }
 
   /**
@@ -297,23 +297,23 @@ public List updateStateList(
       Operand labels,
       Operand predictions,
       Operand sampleWeights) {
-
+    Ops tf = getTF();
     Map> confusionMatrix = new HashMap<>();
     confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives);
     confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives);
 
-    Operand tPredictions = cast(getTF(), predictions, type);
-    Operand tLabels = cast(getTF(), labels, type);
-    Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null;
+    Operand tPredictions = cast(tf, predictions, type);
+    Operand tLabels = cast(tf, labels, type);
+    Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
 
     return new ArrayList(
         MetricsHelper.updateConfusionMatrixVariables(
-            getTF(),
+            tf,
             confusionMatrix,
             Collections.EMPTY_MAP,
             tLabels,
             tPredictions,
-            thresholds,
+            tf.constant(thresholds),
             topK,
             classId,
             tSampleWeights,
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java
index e1eebb98f77..54e9de0d9cf 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java
@@ -328,24 +328,24 @@ public List updateStateList(
       Operand labels,
       Operand predictions,
       Operand sampleWeights) {
-
+    Ops tf = getTF();
     Map> confusionMatrix = new HashMap<>();
     confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives);
     confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives);
 
-    Operand tPredictions = cast(getTF(), predictions, type);
-    Operand tLabels = cast(getTF(), labels, type);
-    Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null;
+    Operand tPredictions = cast(tf, predictions, type);
+    Operand tLabels = cast(tf, labels, type);
+    Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
 
     return MetricsHelper.updateConfusionMatrixVariables(
-        getTF(),
+        tf,
         confusionMatrix,
         Collections.EMPTY_MAP,
         tLabels,
         tPredictions,
-        this.thresholds,
-        this.topK,
-        this.classId,
+        tf.constant(thresholds),
+        topK,
+        classId,
         tSampleWeights,
         false,
         null);
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java
index c9e762d05d4..31e88b6bb31 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java
@@ -140,9 +140,10 @@ public List updateStateList(
       Operand labels,
       Operand predictions,
       Operand sampleWeights) {
-    Operand tLabels = cast(getTF(), labels, type);
-    Operand tPredictions = cast(getTF(), predictions, type);
-    Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null;
+    Ops tf = getTF();
+    Operand tLabels = cast(tf, labels, type);
+    Operand tPredictions = cast(tf, predictions, type);
+    Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
     return new ArrayList<>(
         MetricsHelper.updateConfusionMatrixVariables(
             getTF(),
@@ -150,7 +151,7 @@ public List updateStateList(
             Collections.singletonMap(confusionMatrixCond, initializer),
             tLabels,
             tPredictions,
-            thresholds,
+            tf.constant(thresholds),
             null,
             null,
             tSampleWeights,
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
index 0be0a7a572a..45a236ef814 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -15,6 +15,7 @@
 package org.tensorflow.framework.metrics.impl;
 
 import org.tensorflow.Operand;
+import org.tensorflow.Session;
 import org.tensorflow.framework.losses.impl.LossTuple;
 import org.tensorflow.framework.losses.impl.LossesHelper;
 import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException;
@@ -26,10 +27,7 @@
 import org.tensorflow.op.core.*;
 import org.tensorflow.op.math.Mean;
 import org.tensorflow.op.nn.TopK;
-import org.tensorflow.types.TBool;
-import org.tensorflow.types.TFloat64;
-import org.tensorflow.types.TInt32;
-import org.tensorflow.types.TInt64;
+import org.tensorflow.types.*;
 import org.tensorflow.types.family.TIntegral;
 import org.tensorflow.types.family.TNumber;
 
@@ -277,6 +275,7 @@ public static List assertShapes(
     return updateOperations;
   }
 
+
   /**
    * Returns an op to update the given confusion matrix variables.
    *
@@ -335,7 +334,7 @@ public static  List updateConfusionMatrixVariables(
       Map> varInitializers,
       Operand labels,
       Operand predictions,
-      float[] thresholds,
+      Operand thresholds,
       Integer topK,
       Integer classId,
       Operand sampleWeight,
@@ -349,68 +348,65 @@ public static  List updateConfusionMatrixVariables(
       return Collections.EMPTY_LIST;
     }
 
-    Operand lLabels = labels;
-    Operand lPredictions = predictions;
-    Operand lSampleWeight = sampleWeight;
+    Operand tLabels = labels;
+    Operand tPredictions = predictions;
+    Operand tSampleWeight = sampleWeight;
 
-    Operand numThresholds;
+    Operand numThresholds =
+        tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar()));
     Operand oneThresh;
     if (multiLabel) {
-      numThresholds = tf.shape.size(lLabels, tf.constant(0));
-      oneThresh = tf.math.equal(tf.constant(1), tf.constant(thresholds.length));
+      oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds));
     } else {
       // TODO handle Ragged Tensors????
       // [y_pred,
       //    y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
       //                                                   sampleWeights)
-      numThresholds = tf.constant(thresholds.length);
       oneThresh = tf.constant(true);
+      numThresholds = tf.shape.size(tf.shape(thresholds));
     }
 
     List controlOps = new ArrayList<>();
-    Operand axes = allAxes(tf, lPredictions);
+    Operand axes = allAxes(tf, tPredictions);
     controlOps.add(
         tf.withSubScope("updateConfusionMatrixVariables-1")
             .assertThat(
                 tf.reduceAll(
                     tf.math.greaterEqual(
-                        lPredictions, cast(tf, tf.constant(0), lPredictions.type())),
+                        tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
                     axes),
                 Collections.singletonList(tf.constant("predictions must be >= 0"))));
     controlOps.add(
         tf.withSubScope("updateConfusionMatrixVariables-2")
             .assertThat(
                 tf.reduceAll(
-                    tf.math.lessEqual(lPredictions, cast(tf, tf.constant(1), lPredictions.type())),
+                    tf.math.lessEqual(tPredictions, cast(tf, tf.constant(1), tPredictions.type())),
                     axes),
                 Collections.singletonList(tf.constant("predictions must be <= 1"))));
 
     LossTuple result =
-        LossesHelper.squeezeOrExpandDimensions(tf, lLabels, lPredictions, lSampleWeight);
-    lPredictions = result.getTarget();
-    lLabels = result.getLabels();
-    lSampleWeight = result.getSampleWeights();
+        LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight);
+    tPredictions = result.getTarget();
+    tLabels = result.getLabels();
+    tSampleWeight = result.getSampleWeights();
 
-    if (!lPredictions.shape().isCompatibleWith(lLabels.shape()))
+    if (!tPredictions.shape().isCompatibleWith(tLabels.shape()))
       throw new IllegalArgumentException(
           String.format(
               "Shapes %s and %s are incompatible)",
-              lPredictions.shape().toString(), lLabels.asOutput().shape().toString()));
+              tPredictions.shape().toString(), tLabels.asOutput().shape().toString()));
 
     if (topK != null) {
-      lPredictions = filterTopK(tf, lPredictions, topK);
+      tPredictions = filterTopK(tf, tPredictions, topK);
     }
 
     if (classId != null) {
-      lLabels = tf.squeeze(tf.gather(lLabels, tf.constant(new int[] {classId}), tf.constant(1)));
-      lPredictions =
-          tf.squeeze(tf.gather(lPredictions, tf.constant(new int[] {classId}), tf.constant(1)));
-      lLabels = tf.expandDims(lLabels, tf.constant(0));
-      lPredictions = tf.expandDims(lPredictions, tf.constant(0));
+      tLabels = tf.gather(tLabels, tf.constant(new int[] {classId}), tf.constant(1));
+      tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classId}), tf.constant(1));
     }
-    org.tensorflow.op.core.Shape predShape = tf.shape(lPredictions);
+    org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions);
     Operand numPredictions =
-        tf.reshape(tf.shape.size(lPredictions, tf.constant(0)), tf.constant(Shape.scalar()));
+        tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar()));
     Operand numLabels =
         tf.select(
             tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)),
@@ -424,50 +420,52 @@ lPredictions, cast(tf, tf.constant(0), lPredictions.type())),
     Operand predictionsExtraDim;
     Operand labelsExtraDim;
     if (multiLabel) {
-      predictionsExtraDim = tf.expandDims(lPredictions, tf.constant(0));
-      labelsExtraDim = tf.expandDims(cast(tf, lLabels, TBool.class), tf.constant(0));
+      predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0));
+      labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0));
     } else {
-      predictionsExtraDim = tf.reshape(lPredictions, tf.constant(Shape.of(1, -1)));
-      labelsExtraDim = tf.reshape(cast(tf, lLabels, TBool.class), tf.constant(Shape.of(1, -1)));
+      predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1)));
+      labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1)));
     }
     List> threshPretileShape;
     List> threshTiles;
     List> dataTiles;
     if (multiLabel) {
       threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1));
-
       threshTiles = Arrays.asList(tf.constant(1), numPredictions, threshLabelTile);
       dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1));
     } else {
-      threshPretileShape = Arrays.asList(numThresholds, tf.constant(-1));
+      threshPretileShape =
+          Arrays.asList(tf.reshape(numThresholds, tf.constant(Shape.scalar())), tf.constant(-1));
       Operand mul = tf.math.mul(numPredictions, numLabels);
       threshTiles = Arrays.asList(tf.constant(1), mul);
       dataTiles = Arrays.asList(numThresholds, tf.constant(1));
     }
 
     Operand thresholdsReshaped =
-        tf.reshape(
-            cast(tf, tf.constant(thresholds), predictions.type()), tf.stack(threshPretileShape));
+        tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape));
     Operand threshTilesShape = tf.stack(threshTiles);
     Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape);
-    Operand predsTiled = tf.tile(predictionsExtraDim, tf.stack(dataTiles));
+    Operand stackedTiles = tf.stack(dataTiles);
+
+    Operand predsTiled = tf.tile(predictionsExtraDim, stackedTiles);
 
     // Compare predictions and threshold.
     Operand predIsPos = tf.math.greater(predsTiled, threshTiled);
     // Tile labels by number of thresholds
     Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles));
     Operand weightsTiled;
-    if (lSampleWeight != null) {
-      lSampleWeight =
-          tf.broadcastTo(cast(tf, lSampleWeight, predictions.type()), tf.shape(lPredictions));
-      weightsTiled = tf.tile(tf.reshape(lSampleWeight, tf.stack(threshTiles)), tf.stack(dataTiles));
+    if (tSampleWeight != null) {
+      tSampleWeight =
+          tf.broadcastTo(cast(tf, tSampleWeight, predictions.type()), tf.shape(tPredictions));
+      weightsTiled = tf.tile(tf.reshape(tSampleWeight, tf.stack(threshTiles)), tf.stack(dataTiles));
     } else {
       weightsTiled = null;
     }
 
     if (labelWeights != null) {
       Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0));
-      lLabelWeights = tf.broadcastTo(cast(tf, lLabelWeights, labelWeights.type()), lPredictions);
+
+      lLabelWeights = tf.broadcastTo(lLabelWeights, tPredictions);
       Operand labelWeightsTiled =
           tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles));
       if (weightsTiled == null) {
@@ -520,6 +518,7 @@ lPredictions, cast(tf, tf.constant(0), lPredictions.type())),
     return controlOps;
   }
 
+
   /**
    * Creates an Operand that adds the values by taking the logical and of labels and predictions to
    * the specified confusion matrix variable.
@@ -700,57 +699,57 @@ public static  Operand confusionMatrix(
               predictions.shape().toString(), labels.shape().toString()));
     tf = tf.withSubScope("confusionMatrix");
     LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, predictions, labels, null);
-    Operand lPredictions = cast(tf, ops.getTarget(), TInt64.class);
-    Operand lLabels = cast(tf, ops.getLabels(), TInt64.class);
+    Operand tPredictions = cast(tf, ops.getTarget(), TInt64.class);
+    Operand tLabels = cast(tf, ops.getLabels(), TInt64.class);
 
     List labelControls = new ArrayList<>();
     List predictionControls = new ArrayList<>();
 
     labelControls.add(
         tf.assertThat(
-            tf.reduceAny(tf.math.greaterEqual(lLabels, tf.constant(0L)), allAxes(tf, lLabels)),
+            tf.reduceAny(tf.math.greaterEqual(tLabels, tf.constant(0L)), allAxes(tf, tLabels)),
             Collections.singletonList(tf.constant("`labels` contains negative values"))));
 
     predictionControls.add(
         tf.assertThat(
             tf.reduceAny(
-                tf.math.greaterEqual(lPredictions, tf.constant(0L)), allAxes(tf, lPredictions)),
+                tf.math.greaterEqual(tPredictions, tf.constant(0L)), allAxes(tf, tPredictions)),
             Collections.singletonList(tf.constant("`predictions` contains negative values"))));
     if (numClasses == null) {
       numClasses =
           tf.math.maximum(
-              tf.reduceMax(lPredictions, allAxes(tf, lPredictions)),
-              tf.reduceMax(lLabels, allAxes(tf, lLabels)));
+              tf.reduceMax(tPredictions, allAxes(tf, tPredictions)),
+              tf.reduceMax(tLabels, allAxes(tf, tLabels)));
     } else {
       labelControls.add(
           tf.assertThat(
-              tf.reduceAny(tf.math.less(lLabels, numClasses), allAxes(tf, lLabels)),
+              tf.reduceAny(tf.math.less(tLabels, numClasses), allAxes(tf, tLabels)),
               Collections.singletonList(tf.constant("``labels` out of bounds"))));
       predictionControls.add(
           tf.assertThat(
-              tf.reduceAny(tf.math.less(lPredictions, numClasses), allAxes(tf, lPredictions)),
+              tf.reduceAny(tf.math.less(tPredictions, numClasses), allAxes(tf, tPredictions)),
               Collections.singletonList(tf.constant("``predictions` out of bounds"))));
     }
 
     if (weights != null) {
-      if (!lPredictions.shape().isCompatibleWith(weights.shape())) {
+      if (!tPredictions.shape().isCompatibleWith(weights.shape())) {
         throw new IllegalArgumentException(
             String.format(
                 "Prediction shape %s is not compatible with weights shape %s",
-                lPredictions.shape().toString(), weights.shape().toString()));
+                tPredictions.shape().toString(), weights.shape().toString()));
       }
     }
 
     Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls);
-    lLabels = tfc.identity(lLabels);
+    tLabels = tfc.identity(tLabels);
 
     tfc = tf.withSubScope("confusionMatrixPredicitons").withControlDependencies(predictionControls);
-    lPredictions = tfc.identity(lPredictions);
+    tPredictions = tfc.identity(tPredictions);
 
     Operand shape = tf.stack(Arrays.asList(numClasses, numClasses));
-    Operand indices = tf.stack(Arrays.asList(lLabels, lPredictions), Stack.axis(1L));
+    Operand indices = tf.stack(Arrays.asList(tLabels, tPredictions), Stack.axis(1L));
     Operand values =
-        weights == null ? cast(tf, tf.onesLike(lPredictions), type) : cast(tf, weights, type);
+        weights == null ? cast(tf, tf.onesLike(tPredictions), type) : cast(tf, weights, type);
     SparseTensor cmSparse = new SparseTensor<>(indices, values, shape);
     Operand zeroMatrix = tf.zeros(shape, type);
 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java
index 377124333bd..84898d8a4d3 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java
@@ -143,10 +143,10 @@ public List updateStateList(
     Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
 
     Map> confusionMatrix = new HashMap<>();
-    confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.getTruePositives());
-    confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.getFalsePositives());
-    confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.getTrueNegatives());
-    confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.getFalseNegatives());
+    confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTruePositives());
+    confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, getFalsePositives());
+    confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, getTrueNegatives());
+    confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, getFalseNegatives());
 
     return MetricsHelper.updateConfusionMatrixVariables(
         tf,
@@ -154,7 +154,7 @@ public List updateStateList(
         Collections.EMPTY_MAP,
         tLabels,
         tPredictions,
-        this.getThresholds(),
+        tf.constant(thresholds),
         null,
         null,
         tSampleWeights,
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java
index 35962a568ca..148ca520d3f 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java
@@ -16,6 +16,7 @@
 
 import org.junit.jupiter.api.Test;
 import org.tensorflow.Operand;
+import org.tensorflow.framework.metrics.impl.MetricsHelper;
 import org.tensorflow.framework.utils.TestSession;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
@@ -125,7 +126,6 @@ public void testDivByZero() {
       Op update = instance.updateState(labels, predictions, null);
       session.run(update);
       Operand precision = instance.result();
-
       session.evaluate(0, precision);
     }
   }

From 616ebb2c2a9a5e6e5b4686f9d63bba545c8e3002 Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Tue, 2 Mar 2021 12:01:28 -0500
Subject: [PATCH 09/97] change classId to classIndex

Added comment on Operand numThresholds reshape to scalar.

Added comment to ExtraDims
---
 .../framework/metrics/impl/MetricsHelper.java      | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
index 45a236ef814..302997eb51a 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -309,7 +309,8 @@ public static List assertShapes(
    *     topK is set)
    * @param topK Optional, indicates that the positive labels should be limited to the top k
    *     predictions, may be null.
-   * @param classId Optional, limits the prediction and labels to the specified class
+   * @param classIndex Optional, limits the prediction and labels to the specified class.
+   *                The classIndex is an integer representing a specific classification class's input data.
    * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as
    *     labels, and must be broadcast to labels (i.e., all dimensions
    *     must be either 1, or the same as the corresponding labels
@@ -336,7 +337,7 @@ public static  List updateConfusionMatrixVariables(
       Operand predictions,
       Operand thresholds,
       Integer topK,
-      Integer classId,
+      Integer classIndex,
       Operand sampleWeight,
       boolean multiLabel,
       Operand labelWeights) {
@@ -352,6 +353,7 @@ public static  List updateConfusionMatrixVariables(
     Operand tPredictions = predictions;
     Operand tSampleWeight = sampleWeight;
 
+    // reshape to scalar for operations later.
     Operand numThresholds =
         tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar()));
     Operand oneThresh;
@@ -363,7 +365,6 @@ public static  List updateConfusionMatrixVariables(
       //    y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
       //                                                   sampleWeights)
       oneThresh = tf.constant(true);
-      numThresholds = tf.shape.size(tf.shape(thresholds));
     }
 
     List controlOps = new ArrayList<>();
@@ -400,9 +401,9 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
       tPredictions = filterTopK(tf, tPredictions, topK);
     }
 
-    if (classId != null) {
-      tLabels = tf.gather(tLabels, tf.constant(new int[] {classId}), tf.constant(1));
-      tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classId}), tf.constant(1));
+    if (classIndex != null) {
+      tLabels = tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(1));
+      tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1));
     }
     org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions);
     Operand numPredictions =
@@ -417,6 +418,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
                 tf.constant(0)));
     Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1));
 
+    // The ExtraDims are added so the operands of the tile operations later on are compatible.
     Operand predictionsExtraDim;
     Operand labelsExtraDim;
     if (multiLabel) {

From af69e0068bfedcd323c14e09072acd446eb565d7 Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Tue, 2 Mar 2021 12:02:32 -0500
Subject: [PATCH 10/97] fix spurious "this.".

---
 .../org/tensorflow/framework/metrics/AUC.java | 78 +++++++++----------
 1 file changed, 39 insertions(+), 39 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
index 8a31dfd3fce..1269f3453f3 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
@@ -613,7 +613,7 @@ private Map> build(Shape shape) {
     }
     Ops tf = getTF();
 
-    if (this.isMultiLabel()) {
+    if (isMultiLabel()) {
       if (shape == null) {
         throw new IllegalArgumentException("For multiLabel, a shape must be provided");
       }
@@ -622,14 +622,14 @@ private Map> build(Shape shape) {
             String.format(
                 "labels must have rank=2 when multiLabel is true. Found rank %d.",
                 shape.numDimensions()));
-      this.numLabels = (int) shape.size(1);
-      variableShape = Shape.of(this.numThresholds, this.numLabels);
+      numLabels = (int) shape.size(1);
+      variableShape = Shape.of(numThresholds, numLabels);
     } else {
-      variableShape = Shape.of(this.numThresholds);
+      variableShape = Shape.of(numThresholds);
     }
 
     // Create metric variables
-    Zeros zeros = new Zeros<>(getTF());
+    Zeros zeros = new Zeros<>(tf);
     Operand zero = zeros.call(tf.constant(variableShape), type);
     if (truePositives == null) {
       truePositives = tf.withName(getTruePositivesName()).variable(zero);
@@ -651,7 +651,7 @@ private Map> build(Shape shape) {
       initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, tf.assign(falseNegatives, zero));
     }
 
-    this.initialized = true;
+    initialized = true;
     return initializers;
   }
 
@@ -668,39 +668,39 @@ public List updateStateList(
     Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
     List updateOperations = new ArrayList<>();
     Map> varInitializers = Collections.EMPTY_MAP;
-    if (!this.initialized) {
+    if (!initialized) {
       varInitializers = build(tPredictions.shape());
     }
-    if (this.isMultiLabel() || this.getLabelWeights() != null) {
+    if (isMultiLabel() || getLabelWeights() != null) {
       // labels should have shape (number of examples, number of labels).
       List> symbols = new ArrayList<>();
       symbols.add(new SymbolicShape<>(tLabels, "N", "L"));
-      if (this.isMultiLabel()) {
+      if (isMultiLabel()) {
         // TP, TN, FP, and FN should all have shape
         //(number of thresholds, number of labels).
-        symbols.add(new SymbolicShape<>(this.truePositives, "T", "L"));
-        symbols.add(new SymbolicShape<>(this.falsePositives, "T", "L"));
-        symbols.add(new SymbolicShape<>(this.trueNegatives, "T", "L"));
-        symbols.add(new SymbolicShape<>(this.falseNegatives, "T", "L"));
+        symbols.add(new SymbolicShape<>(truePositives, "T", "L"));
+        symbols.add(new SymbolicShape<>(falsePositives, "T", "L"));
+        symbols.add(new SymbolicShape<>(trueNegatives, "T", "L"));
+        symbols.add(new SymbolicShape<>(falseNegatives, "T", "L"));
       }
-      if (this.getLabelWeights() != null) {
-        symbols.add(new SymbolicShape<>(this.getLabelWeights(), "L", ""));
+      if (getLabelWeights() != null) {
+        symbols.add(new SymbolicShape<>(getLabelWeights(), "L", ""));
       }
       updateOperations.addAll(
           MetricsHelper.assertShapes(tf, symbols, "Number of labels is not consistent."));
     }
 
     Map> confusionMatrix = new HashMap<>();
-    confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives);
-    confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.falsePositives);
-    confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.trueNegatives);
-    confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives);
+    confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives);
+    confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives);
+    confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, trueNegatives);
+    confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, falseNegatives);
 
     // Only forward labelWeights to update_confusion_matrix_variables when
     // multiLabel is false. Otherwise the averaging of individual label AUCs is
     // handled in AUC.result
-    if (this.isMultiLabel()) {
-      this.labelWeights = null;
+    if (isMultiLabel()) {
+      labelWeights = null;
     }
     updateOperations.addAll(
         MetricsHelper.updateConfusionMatrixVariables(
@@ -730,7 +730,7 @@ private Operand interpolatePRAuc() {
         tf.slice(
             truePositives,
             tf.constant(new int[] {0}),
-            tf.constant(new int[] {this.getNumThresholds() - 1}));
+            tf.constant(new int[] {getNumThresholds() - 1}));
     // truePositives[1:]
     Operand tp1 =
         tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
@@ -744,7 +744,7 @@ private Operand interpolatePRAuc() {
             tf.slice(
                 p,
                 tf.constant(new int[] {0}),
-                tf.constant(new int[] {this.getNumThresholds() - 1})),
+                tf.constant(new int[] {getNumThresholds() - 1})),
             tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})));
 
     Operand precisionSlope =
@@ -764,7 +764,7 @@ private Operand interpolatePRAuc() {
                     tf.slice(
                         p,
                         tf.constant(new int[] {0}),
-                        tf.constant(new int[] {this.getNumThresholds() - 1})),
+                        tf.constant(new int[] {getNumThresholds() - 1})),
                     tf.dtypes.cast(tf.constant(0), p.type())),
                 tf.math.greater(
                     tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
@@ -773,7 +773,7 @@ private Operand interpolatePRAuc() {
                 tf.slice(
                     p,
                     tf.constant(new int[] {0}),
-                    tf.constant(new int[] {this.getNumThresholds() - 1})),
+                    tf.constant(new int[] {getNumThresholds() - 1})),
                 tf.math.maximum(
                     tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
                     tf.dtypes.cast(tf.constant(0), p.type()))),
@@ -790,15 +790,15 @@ private Operand interpolatePRAuc() {
         tf.math.divNoNan(
             aucTotalPos,
             tf.math.maximum(
-                tf.math.add(tp1, fn1), tf.dtypes.cast(tf.constant(0), this.truePositives.type())));
+                tf.math.add(tp1, fn1), tf.dtypes.cast(tf.constant(0), truePositives.type())));
 
-    if (this.isMultiLabel()) {
+    if (isMultiLabel()) {
       Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0));
-      if (this.getLabelWeights() == null) {
+      if (getLabelWeights() == null) {
         return MetricsHelper.mean(tf, byLabelAuc);
       } else {
         return tf.math.divNoNan(
-            tf.reduceSum(tf.math.mul(byLabelAuc, this.getLabelWeights()), allAxes(tf, byLabelAuc)),
+            tf.reduceSum(tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, byLabelAuc)),
             tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights())));
       }
     } else {
@@ -810,17 +810,17 @@ private Operand interpolatePRAuc() {
   @Override
   public Operand result() {
 
-    if (this.getCurve() == AUCCurve.PR
-        && this.getSummationMethod() == AUCSummationMethod.INTERPOLATION) {
+    if (getCurve() == AUCCurve.PR
+        && getSummationMethod() == AUCSummationMethod.INTERPOLATION) {
       // This use case is different and is handled separately.
-      return this.interpolatePRAuc();
+      return interpolatePRAuc();
     }
     Ops tf = getTF();
     Operand x;
     Operand y;
     Operand recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives));
 
-    if (this.getCurve() == AUCCurve.ROC) {
+    if (getCurve() == AUCCurve.ROC) {
       x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives));
       y = recall;
     } else { // AUCCurve.PR
@@ -832,12 +832,12 @@ public Operand result() {
     // y[:self.numThresholds - 1]
     Operand ySlice1 =
         tf.slice(
-            y, tf.constant(new int[] {0}), tf.constant(new int[] {this.getNumThresholds() - 1}));
+            y, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1}));
     // y[1:]
     Operand ySlice2 = tf.slice(y, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
 
     Operand heights = null;
-    switch (this.getSummationMethod()) {
+    switch (getSummationMethod()) {
       case INTERPOLATION:
         heights =
             tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type()));
@@ -850,19 +850,19 @@ public Operand result() {
         break;
     }
 
-    if (this.isMultiLabel()) {
+    if (isMultiLabel()) {
       Operand riemannTerms =
           tf.math.mul(
               tf.math.sub(
                   tf.slice(
                       x,
                       tf.constant(new int[] {0}),
-                      tf.constant(new int[] {this.getNumThresholds() - 1})),
+                      tf.constant(new int[] {getNumThresholds() - 1})),
                   tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))),
               heights);
       Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0));
 
-      if (this.getLabelWeights() == null) {
+      if (getLabelWeights() == null) {
         return MetricsHelper.mean(tf, byLabelAuc);
       } else {
         //Weighted average of the label AUCs.
@@ -875,7 +875,7 @@ public Operand result() {
     } else {
       Operand slice1 =
           tf.slice(
-              x, tf.constant(new int[] {0}), tf.constant(new int[] {this.getNumThresholds() - 1}));
+              x, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1}));
       Operand slice2 = tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
       Operand sub = tf.math.sub(slice1, slice2);
       Operand operand = tf.math.mul(sub, heights);

From 05c3d88aec2561ecf69177f40664550870c8b73f Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Wed, 3 Mar 2021 12:31:11 -0500
Subject: [PATCH 11/97] Remove references to keras in javadoc.

---
 .../java/org/tensorflow/framework/metrics/AUC.java    |  2 +-
 .../framework/metrics/SparseCategoricalAccuracy.java  | 11 +----------
 2 files changed, 2 insertions(+), 11 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
index 1269f3453f3..5ac07e98451 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
@@ -58,7 +58,7 @@
  * 

Usage:
* *

- * AUC m = new  tf.keras.metrics.AUC( tf, 3);
+ * AUC m = new  org.tensorflow.framework.metrics.AUC( tf, 3);
  * m.updateState( tf.constant(new float[] {0, 0, 1,1}),
  *          tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
  *
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java
index 7034861d8d2..0d18c1e2dcb 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java
@@ -44,7 +44,7 @@
  * 

Usage: * *

- * SparseCategoricalAccuracy m = new tf.keras.metrics.SparseCategoricalAccuracy();
+ * SparseCategoricalAccuracy m = new org.tensorflow.framework.metrics.SparseCategoricalAccuracy();
  * m.update_state(tf.constant(new float[][] {{2}, {1}},
  *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, [{0.05f, 0.95f, 0f}});
  * Operand<TFloat32> result = m.result();
@@ -63,15 +63,6 @@
  * 0.3
  * 
* - *

Usage with tf.keras API: - * - *

- * Model model = new tf.keras. models.Model(inputs, outputs);
- * model.compile(
- *     "sgd",
- *     loss="mse",
- *     metrics=["sparse_categorical_accuracy"]);
- * 
* * @param The data type for the metric result */ From dcfaa0fb2a2cb9071d2a9edbceccd49a8111db49 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Mar 2021 12:51:53 -0500 Subject: [PATCH 12/97] Fix javadoc --- .../java/org/tensorflow/framework/constraints/Constraint.java | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index d3094b5e9e9..306361959bf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -42,6 +42,7 @@ public Constraint(Ops tf) { * * @param weights the weights * @return the constrained weights + * @param the data type for weights and results. */ public abstract Operand call(Operand weights); From de9bc10b9d90e89712fb6f7a1f84f6095588d564 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Mar 2021 13:27:32 -0500 Subject: [PATCH 13/97] Reformat code and fix labelWeights argument in call to updateConfusionMatrixVariables --- .../org/tensorflow/framework/metrics/AUC.java | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 5ac07e98451..cd83bbeb26d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -209,7 +209,7 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { this( tf, null, - DEFAULT_NUM_THRESHOLDS, + DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, thresholds, @@ -264,7 +264,7 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { this( tf, name, - 
DEFAULT_NUM_THRESHOLDS, + DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, thresholds, @@ -322,7 +322,7 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C this( tf, name, - DEFAULT_NUM_THRESHOLDS, + DEFAULT_NUM_THRESHOLDS, curve, AUCSummationMethod.INTERPOLATION, thresholds, @@ -378,7 +378,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) this( tf, null, - DEFAULT_NUM_THRESHOLDS, + DEFAULT_NUM_THRESHOLDS, curve, AUCSummationMethod.INTERPOLATION, thresholds, @@ -435,7 +435,17 @@ public AUC( AUCSummationMethod summationMethod, long seed, Class type) { - this(tf, null, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type); + this( + tf, + null, + DEFAULT_NUM_THRESHOLDS, + curve, + summationMethod, + thresholds, + false, + null, + seed, + type); } /** @@ -487,7 +497,17 @@ public AUC( AUCSummationMethod summationMethod, long seed, Class type) { - this(tf, name, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type); + this( + tf, + name, + DEFAULT_NUM_THRESHOLDS, + curve, + summationMethod, + thresholds, + false, + null, + seed, + type); } /** @@ -677,7 +697,7 @@ public List updateStateList( symbols.add(new SymbolicShape<>(tLabels, "N", "L")); if (isMultiLabel()) { // TP, TN, FP, and FN should all have shape - //(number of thresholds, number of labels). + // (number of thresholds, number of labels). symbols.add(new SymbolicShape<>(truePositives, "T", "L")); symbols.add(new SymbolicShape<>(falsePositives, "T", "L")); symbols.add(new SymbolicShape<>(trueNegatives, "T", "L")); @@ -699,9 +719,6 @@ public List updateStateList( // Only forward labelWeights to update_confusion_matrix_variables when // multiLabel is false. 
Otherwise the averaging of individual label AUCs is // handled in AUC.result - if (isMultiLabel()) { - labelWeights = null; - } updateOperations.addAll( MetricsHelper.updateConfusionMatrixVariables( tf, @@ -714,7 +731,7 @@ public List updateStateList( null, tSampleWeights, isMultiLabel(), - getLabelWeights())); + isMultiLabel() ? null : getLabelWeights())); return updateOperations; } @@ -742,9 +759,7 @@ private Operand interpolatePRAuc() { Operand dP = tf.math.sub( tf.slice( - p, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})), + p, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})), tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))); Operand precisionSlope = @@ -771,9 +786,7 @@ private Operand interpolatePRAuc() { tf.dtypes.cast(tf.constant(0), p.type()))), tf.math.divNoNan( tf.slice( - p, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})), + p, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})), tf.math.maximum( tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})), tf.dtypes.cast(tf.constant(0), p.type()))), @@ -810,8 +823,7 @@ private Operand interpolatePRAuc() { @Override public Operand result() { - if (getCurve() == AUCCurve.PR - && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { + if (getCurve() == AUCCurve.PR && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { // This use case is different and is handled separately. return interpolatePRAuc(); } @@ -831,16 +843,14 @@ && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { // Find the rectangle heights based on `summationMethod`. 
// y[:self.numThresholds - 1] Operand ySlice1 = - tf.slice( - y, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); + tf.slice(y, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); // y[1:] Operand ySlice2 = tf.slice(y, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); Operand heights = null; switch (getSummationMethod()) { case INTERPOLATION: - heights = - tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type())); + heights = tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type())); break; case MINORING: heights = tf.math.minimum(ySlice1, ySlice2); @@ -865,7 +875,7 @@ && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { if (getLabelWeights() == null) { return MetricsHelper.mean(tf, byLabelAuc); } else { - //Weighted average of the label AUCs. + // Weighted average of the label AUCs. return tf.math.divNoNan( tf.reduceSum( tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())), @@ -874,8 +884,7 @@ && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { } else { Operand slice1 = - tf.slice( - x, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); + tf.slice(x, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); Operand slice2 = tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); Operand sub = tf.math.sub(slice1, slice2); Operand operand = tf.math.mul(sub, heights); From 4520294009e8220e015971c6c9f7f21073a1b6c0 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Mar 2021 13:28:50 -0500 Subject: [PATCH 14/97] Reformat code add code comments and change update_xx (update_fn) to updateXX (updateFN) to eliminate snake case. 
--- .../framework/metrics/impl/MetricsHelper.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 302997eb51a..3d4a2c8dc4f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -413,6 +413,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), tf.constant(1), tf.reduceProd( + // take all but the first dimension tf.shape.takeLast( predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); @@ -479,21 +480,21 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Map loopVars = new HashMap<>(); loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, new Operand[] {labelIsPos, predIsPos}); - Variable update_tn = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); - Variable update_fp = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); - Variable update_fn = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); + Variable updateTN = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); + Variable updateFP = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); + Variable updateFN = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); Operand predIsNeg = null; Operand labelIsNeg; - if (update_fn != null || update_tn != null) { + if (updateFN != null || updateTN != null) { predIsNeg = tf.math.logicalNot(predIsPos); loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, new Operand[] {labelIsPos, predIsNeg}); } - if (update_fp != null || update_tn != null) { + if (updateFP != null || updateTN != null) { labelIsNeg = tf.math.logicalNot(labelIsPos); 
loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, new Operand[] {labelIsNeg, predIsPos}); - if (update_tn != null) { + if (updateTN != null) { loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, new Operand[] {labelIsNeg, predIsNeg}); } } From a0b70415d8dfd34be39695b187c86b4a51e8c352 Mon Sep 17 00:00:00 2001 From: deansher Date: Wed, 3 Mar 2021 08:39:28 -0500 Subject: [PATCH 15/97] Added javadocs and internal docs to AUC.java and MetricsHelper.java --- .../org/tensorflow/framework/metrics/AUC.java | 65 ++++++++++++++++--- .../framework/metrics/impl/MetricsHelper.java | 19 ++++-- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index cd83bbeb26d..cae67dbd4f0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -106,12 +106,46 @@ public class AUC extends Metric { private final String falseNegativesName; private final Map> initializers = new HashMap<>(); private final Class type; + + /** + * The size of the label dimension. + */ private Integer numLabels; + private Operand labelWeights; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ private Variable truePositives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ private Variable falsePositives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. 
+ * + * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ private Variable trueNegatives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ private Variable falseNegatives; + private boolean initialized; /** @@ -515,22 +549,24 @@ public AUC( * * @param tf The TensorFlow Ops * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. - * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values - * must be > 1. + * @param numThresholds the number of thresholds to use when discretizing the roc curve. + * This includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link * AUCCurve#PR} for the Precision-Recall-curve. * @param summationMethod Specifies the Riemann summation method used * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, - * the numThresholds parameter is ignored. Values should be in [0, 1]. + * the numThresholds parameter is ignored. Values should be in [0, 1]. This method + * automatically brackets the provided thresholds with a (-{@link #EPSILON}) + * below and a (1 + {@link #EPSILON}) above. * @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein * AUC is computed separately for each label and then averaged across labels, or (when false) * if the data should be flattened into a single label before AUC computation. In the latter * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an * individual data point. Should be set to false for multi-class data. 
* @param labelWeights non-negative weights used to compute AUCs for multilabel data. When - * multi_label is True, the weights are applied to the individual label AUCs when they are - * averaged to produce the multi-label AUC. When it's false, they are used to weight the - * individual label predictions in computing the confusion matrix on the flattened data. + * multiLabel is true, the weights are applied to the individual label AUCs when + * they are averaged to produce the multi-label AUC. When it's false, they are used to weight + * the individual label predictions in computing the confusion matrix on the flattened data. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. @@ -595,7 +631,7 @@ public AUC( System.arraycopy(thresholds, 0, this.thresholds, 1, thresholds.length); this.thresholds[this.numThresholds - 1] = 1 + EPSILON; - // # Handle multilabel arguments. + // Handle multilabel arguments. if (labelWeights != null) { // assert that labelWeights are non-negative. @@ -675,7 +711,20 @@ private Map> build(Shape shape) { return initializers; } - /** {@inheritDoc} */ + /** + * Creates a List of Operations to update the metric state based on labels and predictions. + * + * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more + * class dimensions, and L1 is a potential extra dimension of size 1 that + * would be squeezed. Will be cast to T. If + * {@link #multiLabel} or if {@link #labelWeights} != null, + * then Cx must be a single dimension. + * @param predictions the predictions shape (N, Cx, P1?). Will be cast to T. + * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to + * T. 
+ * + * @return a List of Operations to update the metric state + */ @Override @SuppressWarnings("unchecked") public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 3d4a2c8dc4f..f38d0896a5f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -299,18 +299,25 @@ public static List assertShapes( * * @param tf the TensorFlow Ops * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. + * corresponding variables to update as values. If multiLabel is + * false then all shapes are (T), where T is the number of thresholds. If + * multiLabel is true then all shapes are (T, C0), where C0 is the number + * of classes. * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to initializer the corresponding variables from * variablesToUpdate. * @param labels the labels, will be cast to {@link TBool} - * @param predictions the predictions whose values are in the range [0, 1]. + * shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more + * class dimensions, and L1 is a potential extra dimension of size 1 that + * would be squeezed. If multiLabel or if + * labelWeights != null, then Cx must be a single dimension. + * @param predictions the predictions shape (N, Cx, P1?). * @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when * topK is set) - * @param topK Optional, indicates that the positive labels should be limited to the top k - * predictions, may be null. 
+ * @param topK Optional, used only if multiLabel, indicates that only the top k + * predictions should be considered. May be null. * @param classIndex Optional, limits the prediction and labels to the specified class. - * The classIndex is and integer representing a specific classification class's input data.. + * The classIndex is an integer index into the first dimension of Cx. * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as * labels, and must be broadcast to labels (i.e., all dimensions * must be either 1, or the same as the corresponding labels @@ -356,6 +363,8 @@ public static List updateConfusionMatrixVariables( // reshape to scalar for operations later. Operand numThresholds = tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); + + // true if we will process thresholds as one-dimensional (possibly because we flatten them) Operand oneThresh; if (multiLabel) { oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); From 4dce2cf6b691f1c25185a331e9309da588a297b2 Mon Sep 17 00:00:00 2001 From: deansher Date: Fri, 5 Mar 2021 07:52:07 -0500 Subject: [PATCH 16/97] Added internal docs to MetricsHelper.java --- .../framework/metrics/impl/MetricsHelper.java | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index f38d0896a5f..b57ab821b4e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -325,7 +325,7 @@ public static List assertShapes( * @param multiLabel indicates whether multidimensional prediction/labels should be treated as * multilabel responses, or flattened into a single label. 
When true, the values of * variablesToUpdate must have a second dimension equal to the number of labels and - * predictions, and those tensors must not be RaggedTensors. + * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. @@ -429,8 +429,14 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); // The ExtraDims are added so the operands of the tile operations later on are compatible. + + // if multilabel, then shape (1, N, D0) + // else shape (1, ND), + // where Dx == Cx except that D0 == 1 if classIndex != null + // ND is the product of N and all Dx Operand predictionsExtraDim; Operand labelsExtraDim; + if (multiLabel) { predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0)); labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0)); @@ -438,9 +444,22 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1))); labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1))); } + + // the shape of each thresholds tile + // if multilabel, then [T, 1, -1] + // else [T, 1], where T is numThresholds List> threshPretileShape; + + // the tiling multiples for thresholds + // if multilabel, then [1, N, threshLabelTile] + // else [1, ND], where ND is the product of N and all Dx List> threshTiles; + + // the tiling multiples for predictionsExtraDim + // If multilabel, then [T, 1, 1] + // else [T, 1] List> dataTiles; + if (multiLabel) { threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); threshTiles = 
Arrays.asList(tf.constant(1), numPredictions, threshLabelTile); @@ -456,9 +475,15 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand thresholdsReshaped = tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); Operand threshTilesShape = tf.stack(threshTiles); + + // if multilabel, then shape (T, N, threshLabelTile) + // else shape (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); + Operand stackedTiles = tf.stack(dataTiles); + // if multilabel, then shape (T, N, D0) + // else shape (T, ND) Operand predsTiled = tf.tile(predictionsExtraDim, stackedTiles); // Compare predictions and threshold. From 7274cf5f659db3c842d5043398eb9cac94d74a5d Mon Sep 17 00:00:00 2001 From: deansher Date: Sat, 6 Mar 2021 11:47:26 -0500 Subject: [PATCH 17/97] Improved internal docs in MetricsHelper.java --- .../framework/metrics/impl/MetricsHelper.java | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index b57ab821b4e..b4834e28764 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -360,6 +360,7 @@ public static List updateConfusionMatrixVariables( Operand tPredictions = predictions; Operand tSampleWeight = sampleWeight; + // size of the 0th dimension of thresholds // reshape to scalar for operations later. 
Operand numThresholds = tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); @@ -415,8 +416,12 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1)); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); + + // number of examples Operand numPredictions = tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); + + // number of labels (or predictions) per example Operand numLabels = tf.select( tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), @@ -426,14 +431,24 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.shape.takeLast( predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); + + // If we will treat thresholds as one-dimensional (always true as of this writing), + // then threshLabelTile is the number of labels (or predictions) per sample. Else it is 1. Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); - // The ExtraDims are added so the operands of the tile operations later on are compatible. + ///////// + // Tile data for threshold comparisons, which is a cross product of thresholds and + // predictions/labels. + // + // In the multilabel case, we want a data shape of (T, N, D0). + // else (T, ND). + // where T is numThresholds + // Dx == Cx except that D0 == 1 if classIndex != null + // ND is the product of N and all Dx. + // In these comments, we refer to all indices beyond the threshold index as a "data position". 
// if multilabel, then shape (1, N, D0) // else shape (1, ND), - // where Dx == Cx except that D0 == 1 if classIndex != null - // ND is the product of N and all Dx Operand predictionsExtraDim; Operand labelsExtraDim; @@ -447,17 +462,19 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), // the shape of each thresholds tile // if multilabel, then [T, 1, -1] - // else [T, 1], where T is numThresholds + // else [T, -1] List> threshPretileShape; // the tiling multiples for thresholds - // if multilabel, then [1, N, threshLabelTile] - // else [1, ND], where ND is the product of N and all Dx + // We want to repeat the thresholds for each data position. + // if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels) + // else [1, ND] List> threshTiles; - // the tiling multiples for predictionsExtraDim + // tiling multiples for predictionsExtraDim and labelsExtraDim + // We want to repeat the predictions and labels for each threshold. // If multilabel, then [T, 1, 1] - // else [T, 1] + // else [T, 1] List> dataTiles; if (multiLabel) { @@ -477,13 +494,13 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand threshTilesShape = tf.stack(threshTiles); // if multilabel, then shape (T, N, threshLabelTile) - // else shape (T, ND) + // else (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); Operand stackedTiles = tf.stack(dataTiles); // if multilabel, then shape (T, N, D0) - // else shape (T, ND) + // else (T, ND) Operand predsTiled = tf.tile(predictionsExtraDim, stackedTiles); // Compare predictions and threshold. From 5e8fac6854405027ee6e5ddc1795eccd1ac0e718 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 10 Mar 2021 10:34:24 -0500 Subject: [PATCH 18/97] Cleanup of updateConfusionMatrixVariables with variable name changes and reuse of previously declared/assigned variables. 
--- .../framework/metrics/impl/MetricsHelper.java | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 3d4a2c8dc4f..3fa602b0a74 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -255,7 +255,7 @@ public static List assertShapes( s -> { Long size = dict.get(s); if (size == null) { - size = symbol.getOperand().asOutput().shape().size((int) ll.get()); + size = symbol.getOperand().shape().size((int) ll.get()); dict.put(s, size); } Op assertion = @@ -395,7 +395,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), throw new IllegalArgumentException( String.format( "Shapes %s and %s are incompatible)", - tPredictions.shape().toString(), tLabels.asOutput().shape().toString())); + tPredictions.shape().toString(), tLabels.shape().toString())); if (topK != null) { tPredictions = filterTopK(tf, tPredictions, topK); @@ -406,7 +406,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1)); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); - Operand numPredictions = + Operand numExamples = tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); Operand numLabels = tf.select( @@ -434,12 +434,12 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), List> dataTiles; if (multiLabel) { threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); - threshTiles = Arrays.asList(tf.constant(1), numPredictions, threshLabelTile); + threshTiles = Arrays.asList(tf.constant(1), numExamples, threshLabelTile); dataTiles = 
Arrays.asList(numThresholds, tf.constant(1), tf.constant(1)); } else { threshPretileShape = Arrays.asList(tf.reshape(numThresholds, tf.constant(Shape.scalar())), tf.constant(-1)); - Operand mul = tf.math.mul(numPredictions, numLabels); + Operand mul = tf.math.mul(numExamples, numLabels); threshTiles = Arrays.asList(tf.constant(1), mul); dataTiles = Arrays.asList(numThresholds, tf.constant(1)); } @@ -448,9 +448,9 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); Operand threshTilesShape = tf.stack(threshTiles); Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); - Operand stackedTiles = tf.stack(dataTiles); + Operand dataTilesShape = tf.stack(dataTiles); - Operand predsTiled = tf.tile(predictionsExtraDim, stackedTiles); + Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); // Compare predictions and threshold. Operand predIsPos = tf.math.greater(predsTiled, threshTiled); @@ -459,8 +459,8 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand weightsTiled; if (tSampleWeight != null) { tSampleWeight = - tf.broadcastTo(cast(tf, tSampleWeight, predictions.type()), tf.shape(tPredictions)); - weightsTiled = tf.tile(tf.reshape(tSampleWeight, tf.stack(threshTiles)), tf.stack(dataTiles)); + tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); + weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); } else { weightsTiled = null; } From 9ea129e3c318afc7adc59fc385963e61ad0f384d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 10 Mar 2021 13:49:30 -0500 Subject: [PATCH 19/97] Reformat code --- .../org/tensorflow/framework/metrics/Metrics.java | 2 -- .../framework/metrics/SparseCategoricalAccuracy.java | 1 - .../framework/metrics/impl/MetricsHelper.java | 12 ++++-------- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 671c967af60..e4cc9c3aa3d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -56,7 +56,6 @@ public static Operand topKCategoricalAccuracy( predictions.type()); } - /** * Computes how often integer targets are in the top K predictions. * @@ -104,5 +103,4 @@ public static Operand sparseTopKCatego tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), predictions.type()); } - } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 0d18c1e2dcb..7bfa7fd6ee9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -63,7 +63,6 @@ * 0.3 *
* - * * @param The data type for the metric result */ public class SparseCategoricalAccuracy extends MeanMetricWrapper diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 3fa602b0a74..5f4735c818c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; -import org.tensorflow.Session; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; @@ -275,7 +274,6 @@ public static List assertShapes( return updateOperations; } - /** * Returns an op to update the given confusion matrix variables. * @@ -309,8 +307,8 @@ public static List assertShapes( * topK is set) * @param topK Optional, indicates that the positive labels should be limited to the top k * predictions, may be null. - * @param classIndex Optional, limits the prediction and labels to the specified class. - * The classIndex is and integer representing a specific classification class's input data.. + * @param classIndex Optional, limits the prediction and labels to the specified class. The + * classIndex is an integer representing a specific classification class's input data. 
* @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as * labels, and must be broadcast to labels (i.e., all dimensions * must be either 1, or the same as the corresponding labels @@ -413,7 +411,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), tf.constant(1), tf.reduceProd( - // take all but the first dimension + // take all but the first dimension tf.shape.takeLast( predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); @@ -458,8 +456,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles)); Operand weightsTiled; if (tSampleWeight != null) { - tSampleWeight = - tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); + tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); } else { weightsTiled = null; @@ -521,7 +518,6 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), return controlOps; } - /** * Creates an Operand that adds the values by taking the logical and of labels and predictions to * the specified confusion matrix variable. 
From 9cb4cc0ef38ea5bcf2c8b5d9633e6e918d310b3d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 15:48:53 -0500 Subject: [PATCH 20/97] Fix JavaDoc for enumerations --- .../framework/metrics/AUCSummationMethod.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java index 60687dd9005..3887f687eea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java @@ -15,18 +15,18 @@ package org.tensorflow.framework.metrics; /** - * Specifies the Riemann summation method used. {@link #INTERPOLATION} (default) applies mid-point - * summation scheme for ROC. For PR-AUC, interpolates (true/false) positives but not the ratio that - * is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left - * summation for increasing intervals and right summation for decreasing intervals; {@link - * #MAJORING} does the opposite. + * Specifies the Riemann summation method used. * * @see Davis & Goadrich. 2006 * @see Riemann summation method */ public enum AUCSummationMethod { + /** Apply mid-point summation scheme for {@link AUCCurve#ROC}. For {@link AUCCurve#PR}, interpolates (true/false) positives but not the ratio that + * is precision */ INTERPOLATION, + /** Apply right summation for increasing intervals and left summation for decreasing intervals */ MAJORING, + /** Apply left summation for increasing intervals and right summation for decreasing intervals */ MINORING; /** From 7503bbcc32658f07f5a5825a031f94ff07fae451 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 15:49:20 -0500 Subject: [PATCH 21/97] Fix JavaDoc to emphasize that this does not inherit from Tensor. 
--- .../org/tensorflow/framework/utils/SparseTensor.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java index 9dee070eea9..81d658ff3a5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java @@ -15,14 +15,19 @@ package org.tensorflow.framework.utils; import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.op.SparseOps; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; /** - * This is a helper class that represents a sparse tensor who's attributes may be passed to - * {@link org.tensorflow.op.Ops#sparse} methods. + * This is a helper class that represents a sparse tensor whose attributes may be passed to {@link + * SparseOps} methods. * - * @param the type of the SparseTensor + *

This class does not inherit from {@link Tensor}, but is merely a place to accumulate the + * properties that are needed for the {@link SparseOps} methods. + * + * @param the type of the SparseTensor's values. */ public class SparseTensor { private final Operand indices; From 19c881c3e949f3888ba324dc15758d3db02b946f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 15:50:06 -0500 Subject: [PATCH 22/97] Fix 'import *' --- .../framework/metrics/impl/MetricsHelper.java | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index d8dea6d062e..8a71da87756 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -22,15 +22,31 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; + +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Rank; import org.tensorflow.op.core.Stack; -import org.tensorflow.op.core.*; +import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; import org.tensorflow.op.nn.TopK; -import org.tensorflow.types.*; + + +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import java.util.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import static 
org.tensorflow.framework.losses.impl.LossesHelper.allAxes; @@ -81,7 +97,7 @@ public static Op assertBroadcastable( && !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") - .withControlDependencies(Collections.EMPTY_LIST) + .withControlDependencies(java.util.Collections.EMPTY_LIST) .noOp(); } if (weightsRankStatic != valuesRankStatic) { @@ -793,7 +809,7 @@ public static Operand confusionMatrix( Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls); tLabels = tfc.identity(tLabels); - tfc = tf.withSubScope("confusionMatrixPredicitons").withControlDependencies(predictionControls); + tfc = tf.withSubScope("confusionMatrixPredictions").withControlDependencies(predictionControls); tPredictions = tfc.identity(tPredictions); Operand shape = tf.stack(Arrays.asList(numClasses, numClasses)); @@ -803,6 +819,7 @@ public static Operand confusionMatrix( SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); Operand zeroMatrix = tf.zeros(shape, type); + return tf.sparse.sparseTensorDenseAdd( cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); } From ebecb5edc15d10ff357744d58096105fd2b586d2 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 15:51:14 -0500 Subject: [PATCH 23/97] Fix casts --- .../tensorflow/framework/metrics/Metrics.java | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index e4cc9c3aa3d..bcf2ea4c880 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -15,13 +15,14 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; -import 
org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** Helper class with built-in metrics functions. */ public class Metrics { @@ -49,8 +50,8 @@ public class Metrics { */ public static Operand topKCategoricalAccuracy( Ops tf, Operand labels, Operand predictions, long k) { - Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); - return CastHelper.cast( + Operand fPredictions = cast(tf, predictions, TFloat32.class); + return cast( tf, tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); @@ -81,15 +82,13 @@ public static Operand topKCategoricalAccuracy( @SuppressWarnings("unchecked") public static Operand sparseTopKCategoricalAccuracy( Ops tf, Operand labels, Operand predictions, int k) { - Operand tLabels; - if (labels.type() != predictions.type()) - tLabels = CastHelper.cast(tf, labels, predictions.type()); - else tLabels = (Operand) labels; + Operand tLabels = cast(tf, labels, predictions.type()); + int predictionsRank = predictions.shape().numDimensions(); int labelsRank = tLabels.shape().numDimensions(); - Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); + Operand castPredictions = cast(tf, predictions, TFloat32.class); if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { if (predictionsRank > 2) { castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); @@ -98,9 +97,9 @@ public static Operand sparseTopKCatego tLabels = tf.shape.flatten(tLabels); } } - return CastHelper.cast( + return cast( tf, - tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), + tf.nn.inTopK(castPredictions, cast(tf, tLabels, TInt32.class), tf.constant(k)), 
predictions.type()); } } From 478441ea8072e3ac4935953753ce93d039f66fae Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 19:03:25 -0500 Subject: [PATCH 24/97] Reformat code --- .../org/tensorflow/framework/metrics/AUC.java | 36 +++++++++---------- .../framework/metrics/impl/MetricsHelper.java | 25 +++++-------- 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index cae67dbd4f0..9fbb3a3ad09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -107,9 +107,7 @@ public class AUC extends Metric { private final Map> initializers = new HashMap<>(); private final Class type; - /** - * The size of the label dimension. - */ + /** The size of the label dimension. */ private Integer numLabels; private Operand labelWeights; @@ -117,7 +115,7 @@ public class AUC extends Metric { /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * - * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single * class dimension within each example. */ private Variable truePositives; @@ -125,7 +123,7 @@ public class AUC extends Metric { /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * - * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single * class dimension within each example. */ private Variable falsePositives; @@ -133,7 +131,7 @@ public class AUC extends Metric { /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * - * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single * class dimension within each example. */ private Variable trueNegatives; @@ -141,7 +139,7 @@ public class AUC extends Metric { /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * - * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single * class dimension within each example. */ private Variable falseNegatives; @@ -549,8 +547,8 @@ public AUC( * * @param tf The TensorFlow Ops * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. - * @param numThresholds the number of thresholds to use when discretizing the roc curve. - * This includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2. + * @param numThresholds the number of thresholds to use when discretizing the roc curve. This + * includes the bracketing 0 and 1 thresholds, so the value must be &ge; 2. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link * AUCCurve#PR} for the Precision-Recall-curve. * @param summationMethod Specifies the Riemann summation method used @@ -563,10 +561,10 @@ public AUC( * if the data should be flattened into a single label before AUC computation. In the latter * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an * individual data point. Should be set to false for multi-class data. - * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When - * multiLabel is true, the weights are applied to the individual label AUCs when - * they are averaged to produce the multi-label AUC. When it's false, they are used to weight - * the individual label predictions in computing the confusion matrix on the flattened data. + * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When + * multiLabel is true, the weights are applied to the individual label AUCs when they + * are averaged to produce the multi-label AUC. When it's false, they are used to weight the + * individual label predictions in computing the confusion matrix on the flattened data. * @param seed the seed for random number generation. 
An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. @@ -714,15 +712,13 @@ private Map> build(Shape shape) { /** * Creates a List of Operations to update the metric state based on labels and predictions. * - * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more - * class dimensions, and L1 is a potential extra dimension of size 1 that - * would be squeezed. Will be cast to T. If - * {@link #multiLabel} or if {@link #labelWeights} != null, - * then Cx must be a single dimension. + * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class + * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be + * cast to T. If {@link #multiLabel} or if {@link #labelWeights} != null + * , then Cx must be a single dimension. * @param predictions the predictions shape (N, Cx, P1?). Will be cast to T. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to - * T. - * + * T. 
* @return a List of Operations to update the metric state */ @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 8a71da87756..cf1755cad56 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -31,7 +31,6 @@ import org.tensorflow.op.math.Mean; import org.tensorflow.op.nn.TopK; - import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -40,7 +39,6 @@ import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -313,25 +311,23 @@ public static List assertShapes( * * @param tf the TensorFlow Ops * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If multiLabel is - * false then all shapes are (T), where T is the number of thresholds. If - * multiLabel is true then all shapes are (T, C0), where C0 is the number - * of classes. + * corresponding variables to update as values. If multiLabel is false then all + * shapes are (T), where T is the number of thresholds. If multiLabel is true + * then all shapes are (T, C0), where C0 is the number of classes. * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to initializer the corresponding variables from * variablesToUpdate. - * @param labels the labels, will be cast to {@link TBool} - * shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more - * class dimensions, and L1 is a potential extra dimension of size 1 that - * would be squeezed. 
If multiLabel or if - * labelWeights != null, then Cx must be a single dimension. + * @param labels the labels, will be cast to {@link TBool} shape (N, Cx, L1?) where N is the + * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra + * dimension of size 1 that would be squeezed. If multiLabel or if + * labelWeights != null, then Cx must be a single dimension. * @param predictions the predictions shape (N, Cx, P1?). * @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when * topK is set) * @param topK Optional, used only if multiLabel, indicates that only the top k * predictions should be considered. May be null. - * @param classIndex Optional, limits the prediction and labels to the specified class. - * The classIndex is an integer index into the first dimension of Cx. + * @param classIndex Optional, limits the prediction and labels to the specified class. The + * classIndex is an integer index into the first dimension of Cx. * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as * labels, and must be broadcast to labels (i.e., all dimensions * must be either 1, or the same as the corresponding labels @@ -510,13 +506,11 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), // else (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); - // if multilabel, then shape (T, N, D0) // else (T, ND) Operand dataTilesShape = tf.stack(dataTiles); Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); - // Compare predictions and threshold. 
Operand predIsPos = tf.math.greater(predsTiled, threshTiled); // Tile labels by number of thresholds @@ -819,7 +813,6 @@ public static Operand confusionMatrix( SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); Operand zeroMatrix = tf.zeros(shape, type); - return tf.sparse.sparseTensorDenseAdd( cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); } From e9bea47e12304fe054410d54ebe119f6f1a26c2f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 19:04:27 -0500 Subject: [PATCH 25/97] Reformat code --- .../tensorflow/framework/metrics/AUCTest.java | 6 +++-- .../framework/metrics/BinaryAccuracyTest.java | 3 +-- .../metrics/CategoricalAccuracyTest.java | 9 +++---- .../framework/metrics/PrecisionTest.java | 16 ++++--------- .../framework/metrics/RecallTest.java | 24 +++++++------------ 5 files changed, 21 insertions(+), 37 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java index 88825b5f32e..857a5c93f7a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -23,7 +23,9 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.tensorflow.framework.utils.CastHelper.cast; public class AUCTest { @@ -199,7 +201,7 @@ public void testWeightedRocMinoring() { session.run(update); Operand result = instance.result(); - float expectedResult = ( 0.5714285f + 0f * 0f); + float expectedResult = (0.5714285f + 0f * 0f); session.evaluate(expectedResult, result); } } diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java index e8d8350dcdc..d203815f4ab 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java @@ -155,8 +155,7 @@ public void testVariableState() { public void testBinaryAccuracyAThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryAccuracy instance = - new BinaryAccuracy<>(tf, 0.7f, 1001L, TFloat32.class); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 0.7f, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 1, 0, 0}; float[] predArray = {0.9f, 0.6f, 0.4f, 0.8f}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java index 83990cbaebb..aea2e4e0d6e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java @@ -31,8 +31,7 @@ public class CategoricalAccuracyTest { public void testCorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = - new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { 0, 0, 1, @@ -60,8 +59,7 @@ public void testCorrect() { public void testSampleWeight() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = - new 
CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { 0, 0, 1, @@ -92,8 +90,7 @@ public void testSampleWeight() { public void testVariableState() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = - new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { 0, 0, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java index 148ca520d3f..cfe5b483e2b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java @@ -16,7 +16,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.MetricsHelper; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -203,8 +202,7 @@ public void testUnweightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 3 - Precision instance = - new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); @@ -220,8 +218,7 @@ public void testWeightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 3 - Precision instance = - 
new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[] {0.2f, 0.1f, 0.4f, 0f, 0.2f}); @@ -249,8 +246,7 @@ public void testUnweightedClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set classId to 2 - Precision instance = - new Precision<>(tf, null, null, 2, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, null, null, 2, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); @@ -290,8 +286,7 @@ public void testUnweightedTopKAndClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK and classId to 2 - Precision instance = - new Precision<>(tf, null, 2, 2, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, null, 2, 2, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0f, 0.2f}}); @@ -321,8 +316,7 @@ public void testUnweightedTopKAndThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 2 - Precision instance = - new Precision<>(tf, 0.7f, 2, null, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, 0.7f, 2, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java index b9d067a6ed2..bd9fbb1ab66 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java @@ -150,8 +150,7 @@ public void testDivByZero() { public void testUnweightedWithThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{1, 0, 0.6f, 0}}); @@ -169,8 +168,7 @@ public void testUnweightedWithThreshold() { public void testWeightedWithThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); @@ -192,8 +190,7 @@ public void testWeightedWithThreshold() { public void testMultipleUpdates() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); @@ -215,8 +212,7 @@ public void testMultipleUpdates() { public void testUnweightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand labels = tf.constant(new float[][] {{0f, 1f, 1f, 0f, 0f}}); @@ -233,8 +229,7 @@ public void 
testUnweightedTopK() { public void testWeightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 1}}); @@ -262,8 +257,7 @@ public void testWeightedTopK() { public void testUnweightedClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, null, null, 2, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, null, null, 2, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); @@ -296,8 +290,7 @@ public void testUnweightedClassId() { public void testUnweightedTopKAndClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, null, 2, 2, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, null, 2, 2, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0, 0.2f}}); @@ -324,8 +317,7 @@ public void testUnweightedTopKAndClassId() { public void testUnweightedTopKAndThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, 0.7f, 2, null, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, 0.7f, 2, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); From 8ae78cd120b3b58b512440a89cfcdf71639ccc68 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 19:24:24 -0500 Subject: [PATCH 26/97] 
Fix javadoc change >= to ≥ --- .../main/java/org/tensorflow/framework/metrics/AUC.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 9fbb3a3ad09..420b1567496 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -26,7 +26,12 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -548,7 +553,7 @@ public AUC( * @param tf The TensorFlow Ops * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. * @param numThresholds the number of thresholds to use when discretizing the roc curve. This - * includes the bracketing 0 and 1 thresholds, so the value must be &GE; 2. + * includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link * AUCCurve#PR} for the Precision-Recall-curve. 
* @param summationMethod Specifies the Riemann summation method used From b56344fd7d56dae9fc1aa8e97407cd03d195e249 Mon Sep 17 00:00:00 2001 From: deansher Date: Thu, 18 Mar 2021 07:53:34 -0400 Subject: [PATCH 27/97] Revised and improved internal docs in MetricsHelper.java --- .../framework/metrics/impl/MetricsHelper.java | 122 +++++++++++------- 1 file changed, 75 insertions(+), 47 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index cf1755cad56..32413aa87af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -304,45 +304,47 @@ public static List assertShapes( * will repeat the same for every threshold. * *

For estimation of these metrics over a stream of data, the function creates an `update_op` - * operation that updates the given variables. + * operation that updates the given variables.

* - *

If sampleWeight is null, weights default to 1. Use weights of 0 to - * mask values. + *

labels, predictions, and sampleWeight tensors are + * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. + * sampleWeight is then broadcast to the shape of predictions.

* * @param tf the TensorFlow Ops - * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If multiLabel is false then all - * shapes are (T), where T is the number of thresholds. If multiLabel is true - * then all shapes are (T, C0), where C0 is the number of classes. - * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding initializer Operands to initializer the corresponding variables from - * variablesToUpdate. - * @param labels the labels, will be cast to {@link TBool} shape (N, Cx, L1?) where N is the + * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding variables to update as values. If multiLabel, then the + * variable shapes are (T, D), where T is the number of thresholds and D is the number of + * classes (after slicing by classIndex, if provided). + * If multiLabels, then the variable shapes are (T). + * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding initializer Operands to for variablesToUpdate. + * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra - * dimension of size 1 that would be squeezed. If multiLabel or if - * labelWeights != null, then Cx must be a single dimension. - * @param predictions the predictions shape (N, Cx, P1?). - * @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when - * topK is set) - * @param topK Optional, used only if multiLabel, indicates that only the top k - * predictions should be considered. May be null. - * @param classIndex Optional, limits the prediction and labels to the specified class. The - * classIndex is an integer index into the first dimension of Cx. 
- * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as - * labels, and must be broadcast to labels (i.e., all dimensions - * must be either 1, or the same as the corresponding labels - * dimension). + * dimension of size 1 that would be squeezed. + * @param predictions the predictions shape (N, Cx, P1?) + * @param thresholds thresholds in the range [0, 1], or {@link #NEG_INF} is used + * when topK is set + * @param topK optional, indicates that only the top k predictions should be considered. + * Applied before possibly slicing by classIndex. + * @param classIndex optional, limits the prediction and labels to the specified class. + * This is an integer index into the first dimension of Cx. + * @param sampleWeight optional Tensor that is aligned with labels and predictions + * as explained above. Use weights of 0 to mask values. * @param multiLabel indicates whether multidimensional prediction/labels should be treated as * multilabel responses, or flattened into a single label. When true, the values of * variablesToUpdate must have a second dimension equal to the number of labels and * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES - * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. + * without explicit multilabel handling (i.e. when the data is to be flattened). + * Must have shape (Dx), which is the same as (Cx) referenced above, except that if + * classIndex is provided, then the final dimension of Dx is 1. These weights + * will be broadcast across the 0th dimension (the examples dimension) of + * predictions. May be null. Must be null if multiLabel. 
* @param the data type for the variables * @throws IllegalArgumentException If predictions and labels have * mismatched shapes, or if sampleWeight is not nulland its shape - * doesn't match predictions + * doesn't match predictions, or if multiLabel && labelWeights != null. * @return an op to update the given confusion matrix variables. */ @SuppressWarnings({"unchecked", "rawtypes"}) @@ -370,12 +372,25 @@ public static List updateConfusionMatrixVariables( Operand tPredictions = predictions; Operand tSampleWeight = sampleWeight; + // We will tile data for threshold comparisons. We want a cross product of thresholds and + // predictions/labels: + // In the multilabel case, we want a data shape of (T, N, D). + // else (T, ND). + // where + // T is numThresholds (the size of the 0th dimension of thresholds) + // N is the number of examples (the 0th dimension of labels and predictions) + // Dx == Cx except that if classIndex != null, + // then the last dimension of Dx is size 1 + // D is the product of all Dx + // ND is N * D + // size of the 0th dimension of thresholds // reshape to scalar for operations later. 
Operand numThresholds = tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); - // true if we will process thresholds as one-dimensional (possibly because we flatten them) + // if multilabel, then (rank(thresholds) == 1) + // else true Operand oneThresh; if (multiLabel) { oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); @@ -407,9 +422,9 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), LossTuple result = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight); - tPredictions = result.getTarget(); - tLabels = result.getLabels(); - tSampleWeight = result.getSampleWeights(); + tPredictions = result.getTarget(); // shape (N, Cx) + tLabels = result.getLabels(); // shape (N, Cx) + tSampleWeight = result.getSampleWeights(); // broadcastable to (N, Dx) if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) throw new IllegalArgumentException( @@ -422,6 +437,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), } if (classIndex != null) { + // Slice to new shapes (N, Dx) tLabels = tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(1)); tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1)); } @@ -430,7 +446,8 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand numExamples = tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); - // number of labels (or predictions) per example + // number of labels (and predictions) per example (after possibly slicing by classIndex) + // In the notation we are using for comments, we'll call this D. 
Operand numLabels = tf.select( tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), @@ -441,22 +458,11 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); - // If we will treat thresholds as one-dimensional (always true as of this writing), - // then threshLabelTile is the number of labels (or predictions) per sample. Else it is 1. + // threshLabelTile == numLabels except in one case: + // if multilabel and rank(thresholds) != 1, then threshLabelTile is 1 Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); - ///////// - // Tile data for threshold comparisons, which is a cross product of thresholds and - // predictions/labels. - // - // In the multilabel case, we want a data shape of (T, N, D0). - // else (T, ND). - // where T is numThresholds - // Dx == Cx except that D0 == 1 if classIndex != null - // ND is the product of N and all Dx. - // In these comments, we refer to all indices beyond the threshold index as a "data position". 
- - // if multilabel, then shape (1, N, D0) + // if multilabel, then shape (1, N, Dx) // else shape (1, ND), Operand predictionsExtraDim; Operand labelsExtraDim; @@ -498,17 +504,22 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), dataTiles = Arrays.asList(numThresholds, tf.constant(1)); } + // if multilabel, then shape (T, 1, T*) + // else shape (T, T*) + // where T* is the product of all threshold dimension sizes beyond 0 Operand thresholdsReshaped = tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); + Operand threshTilesShape = tf.stack(threshTiles); // if multilabel, then shape (T, N, threshLabelTile) // else (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); - // if multilabel, then shape (T, N, D0) - // else (T, ND) Operand dataTilesShape = tf.stack(dataTiles); + + // if multilabel, then shape (T, N, D) + // else (T, ND) Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); // Compare predictions and threshold. @@ -518,17 +529,30 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand weightsTiled; if (tSampleWeight != null) { tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); + // if multilabel, then + // reshape tSampleWeight to (1, N, threshLabelTile) + // tile the result into shape (T, N, threshLabelTile) + // where threshLabelTile is typically D + // else + // reshape tSampleWeight to (1, ND) + // tile the result into shape (T, ND) weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); } else { weightsTiled = null; } if (labelWeights != null) { + // Change shape to (1, Dx). Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0)); + // Broadcast to shape (N, Dx). 
lLabelWeights = tf.broadcastTo(lLabelWeights, tPredictions); + + // If multilabel: shape (T, N, D) + // else: shape (T, ND) Operand labelWeightsTiled = tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles)); + if (weightsTiled == null) { weightsTiled = labelWeightsTiled; } else { @@ -606,6 +630,10 @@ private static Operand weightedAssignAdd( Operand lWeights = cast(tf, weights, type); labelAndPred = tf.math.mul(labelAndPred, lWeights); } + // if multilabel: + // sum across examples, leaving shape (T, D) + // else: + // sum across ND, leaving shape (T) Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); Operand assignAdd; if (initializer != null) { From b4373dcc49477a77661828a9348b8f6c0cb8f68b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 18 Mar 2021 12:09:54 -0400 Subject: [PATCH 28/97] Fix spelling in JavaDoc --- .../java/org/tensorflow/framework/metrics/AUC.java | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 420b1567496..72a1f022b41 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -63,7 +63,7 @@ *

Usage:
* *

- * AUC m = new  org.tensorflow.framework.metrcis.AUC( tf, 3);
+ * AUC m = new  org.tensorflow.framework.metrics.AUC( tf, 3);
  * m.updateState( tf.constant(new float[] {0, 0, 1,1}),
  *          tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
  *
@@ -603,7 +603,7 @@ public AUC(
         if (t < 0.0f || t > 1.0f) {
           throw new IllegalArgumentException(
               String.format(
-                  "Threshold values must be in [0, 1]. Invalid values: %s",
+                  "Threshold values must be in range [0, 1], inclusive. Invalid values: %s",
                   Arrays.toString(thresholds)));
         }
       }
@@ -621,12 +621,7 @@ public AUC(
         thresholds[i] = (i + 1) * 1.0f / (this.numThresholds - 1);
       }
     }
-    // Add an endpoint "threshold" below zero and above one for either
-    // threshold method to account for floating point imprecision.
-    if (thresholds.length != this.numThresholds - 2) {
-      throw new IllegalArgumentException(
-          "Thresholds length must contain numThresholds - 2 entries");
-    }
+
     // Add an endpoint "threshold" below zero and above one for either
     // threshold method to account for floating point imprecisions.
     this.thresholds = new float[this.numThresholds];
@@ -754,7 +749,7 @@ public List updateStateList(
         symbols.add(new SymbolicShape<>(falseNegatives, "T", "L"));
       }
       if (getLabelWeights() != null) {
-        symbols.add(new SymbolicShape<>(getLabelWeights(), "L", ""));
+        symbols.add(new SymbolicShape<>(getLabelWeights(), "L"));
       }
       updateOperations.addAll(
           MetricsHelper.assertShapes(tf, symbols, "Number of labels is not consistent."));

From 99d261052faceed4955b401c82d4acf244d4553b Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Thu, 18 Mar 2021 12:11:33 -0400
Subject: [PATCH 29/97] Change assertShapes to use runtime sizes as Operands
 rather than use primitive long.

---
 .../tensorflow/framework/metrics/impl/MetricsHelper.java | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
index cf1755cad56..07a630560fe 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -255,7 +255,7 @@ public static List assertShapes(
           updateOperations.add(assertion);
         });
 
-    Map dict = new HashMap<>();
+    Map> dict = new HashMap<>();
 
     // check that each operand's dimension size equals the corresponding symbolic shape's dimensions
     // size
@@ -266,9 +266,10 @@ public static List assertShapes(
               .getSymbols()
               .forEach(
                   s -> {
-                    Long size = dict.get(s);
+                    Operand size = dict.get(s);
                     if (size == null) {
-                      size = symbol.getOperand().shape().size((int) ll.get());
+                      // save size for later checks
+                      size = tf.shape.size( symbol.getOperand(), tf.constant(ll.get()), TInt64.class);
                       dict.put(s, size);
                     }
                     Op assertion =
@@ -279,7 +280,7 @@ public static List assertShapes(
                                         symbol.getOperand(),
                                         tf.constant(ll.getAndIncrement()),
                                         TInt64.class),
-                                    tf.constant(size)),
+                                        size),
                                 Collections.singletonList(tf.constant(message)));
                     updateOperations.add(assertion);
                   });

From e379582de4708e576870ca63f6ce5232a4b12ab0 Mon Sep 17 00:00:00 2001
From: deansher 
Date: Fri, 19 Mar 2021 11:11:20 -0400
Subject: [PATCH 30/97] Tweaked internal docs in MetricsHelper.java

---
 .../tensorflow/framework/metrics/impl/MetricsHelper.java  | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
index 32413aa87af..1121c18f4b7 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -447,7 +447,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
         tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar()));
 
     // number of labels (and predictions) per example (after possibly slicing by classIndex)
-    // In the notation we are using for comments, we'll call this D.
+    // In the notation we are using for comments, this is D.
     Operand numLabels =
         tf.select(
             tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)),
@@ -512,8 +512,10 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
 
     Operand threshTilesShape = tf.stack(threshTiles);
 
-    // if multilabel, then shape (T, N, threshLabelTile)
-    //                      else (T, ND)
+    // if multilabel, then
+    //     if thresholds has rank > 1, then shape (T, N, T*)
+    //                                 else shape (T, N, D)
+    // else shape (T, ND)
     Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape);
 
     Operand dataTilesShape = tf.stack(dataTiles);

From 96dce5cdc494c4ee5b4c816510048030440c89ba Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Sun, 21 Mar 2021 08:09:04 -0400
Subject: [PATCH 31/97] Replace calls to tf.slice with private method slice to
 clean up code. Added methods isPositive and posivite to clarify what was
 being done

---
 .../org/tensorflow/framework/metrics/AUC.java | 158 ++++++++++++------
 1 file changed, 103 insertions(+), 55 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
index 72a1f022b41..b09154b46da 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
@@ -24,6 +24,7 @@
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.types.TBool;
 import org.tensorflow.types.family.TNumber;
 
 import java.util.ArrayList;
@@ -780,65 +781,108 @@ public List updateStateList(
     return updateOperations;
   }
 
+  /**
+   * Gets the input with all positive numbers. Negative numbers are set to 0.
+   *
+   * @param input the input
+   * @return the input with all positive numbers.
+   */
+  private Operand positive(Operand input) {
+    return getTF().math.maximum(input, cast(getTF(), getTF().constant(0), input.type()));
+  }
+
+  /**
+   * Gets an operand that determines whether the input consists of each value is greater than 0.
+   *
+   * @param input the input
+   * @return an operand that determines whether the input consists of all values greater than 0.
+   */
+  private Operand isPositive(Operand input) {
+    return getTF().math.greater(input, cast(getTF(), getTF().constant(0), input.type()));
+  }
+
+  /**
+   * Extracts a slice from the input.
+   *
+   * @param input the input
+   * @param begin the beginning location of the slice
+   * @param size the size of the slice
+   * @return the slice
+   */
+  private Operand slice(Operand input, int begin, int size) {
+    return getTF()
+        .slice(input, getTF().constant(new int[] {begin}), getTF().constant(new int[] {size}));
+  }
+
   /**
    * Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
    *
+   * 

Note here we derive & use a closed formula not present in the paper as follows: + *

+   *     Precision = TP / (TP + FP) = TP / P
+   * 
+ *

Modeling all of TP (true positive), FP (false positive) and their sum + * P = TP + FP (predicted positive) as varying linearly within each interval + * [A, B] between successive thresholds, we get

+ *
+   *     Precision slope = dTP / dP
+   *                     = (TP_B - TP_A) / (P_B - P_A)
+   *                     = (TP - TP_A) / (P - P_A)
+   *     Precision = (TP_A + slope * (P - P_A)) / P
+   * 
+ *

The area within the interval is (slope / total_pos_weight) times + *

+   *       int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+   *       int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+   * 
+ * where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + *
+   *       int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+   * 
+ * Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + *
+   *       slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
+   * 
+ * where dTP == TP_B - TP_A. + * Note that when P_A == 0 the above calculation simplifies into + *
+   *       int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+   * 
+ * which is really equivalent to imputing constant precision throughout the + * first bucket having >0 true positives. + * * @return an approximation of the area under the P-R curve. + * @see The Relationship Between Precision-Recall and ROC Curves - Davis & Goadrich 2006 */ private Operand interpolatePRAuc() { // truePositives[:self.numThresholds - 1] Ops tf = getTF(); - Operand tp0 = - tf.slice( - truePositives, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})); + Operand tp0 = slice(truePositives, 0, getNumThresholds() - 1); // truePositives[1:] - Operand tp1 = - tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand tp1 = slice(truePositives, 1, -1); Operand dTP = tf.math.sub(tp0, tp1); Operand p = tf.math.add(truePositives, falsePositives); - Operand dP = - tf.math.sub( - tf.slice( - p, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})), - tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))); + Operand dP = tf.math.sub(slice(p, 0, getNumThresholds() - 1), slice(p, 1, -1)); Operand precisionSlope = - tf.math.divNoNan(dTP, tf.math.maximum(dP, tf.dtypes.cast(tf.constant(0), dP.type()))); + tf.math.divNoNan(dTP, positive(dP)); Operand intercept = - tf.math.sub( - tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1})), - tf.math.mul( - precisionSlope, - tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})))); + tf.math.sub(slice(truePositives, 1, -1), tf.math.mul(precisionSlope, slice(p, 1, -1))); Operand safePRatio = tf.select( tf.math.logicalAnd( - tf.math.greater( - tf.slice( - p, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})), - tf.dtypes.cast(tf.constant(0), p.type())), - tf.math.greater( - tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})), - tf.dtypes.cast(tf.constant(0), p.type()))), + isPositive(slice(p, 0, getNumThresholds() - 1)), isPositive(slice(p, 1, 
-1))), tf.math.divNoNan( - tf.slice( - p, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})), - tf.math.maximum( - tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})), - tf.dtypes.cast(tf.constant(0), p.type()))), - tf.onesLike(tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})))); + slice(p, 0, getNumThresholds() - 1), + positive(slice(p, 1, -1))), + tf.onesLike(slice(p, 1, -1))); - Operand fn1 = - tf.slice(falseNegatives, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand fn1 = slice(falseNegatives, 1, -1); Operand aucTotalPos = tf.math.mul( @@ -847,14 +891,15 @@ private Operand interpolatePRAuc() { Operand prAucIncrement = tf.math.divNoNan( aucTotalPos, - tf.math.maximum( - tf.math.add(tp1, fn1), tf.dtypes.cast(tf.constant(0), truePositives.type()))); + positive(tf.math.add(tp1, fn1))); if (isMultiLabel()) { Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0)); if (getLabelWeights() == null) { + //Evenly weighted average of the label AUCs. return MetricsHelper.mean(tf, byLabelAuc); } else { + // Weighted average of the label AUCs. 
return tf.math.divNoNan( tf.reduceSum(tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, byLabelAuc)), tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()))); @@ -877,22 +922,27 @@ public Operand result() { Operand y; Operand recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); - if (getCurve() == AUCCurve.ROC) { - x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives)); - y = recall; - } else { // AUCCurve.PR - y = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); - x = recall; + switch (getCurve()) { + case ROC: + x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives)); + y = recall; + break; + case PR: + y = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + x = recall; + break; + default: + throw new IllegalArgumentException("Unexpected AUCCurve value: " + getCurve()); } // Find the rectangle heights based on `summationMethod`. // y[:self.numThresholds - 1] - Operand ySlice1 = - tf.slice(y, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); + Operand ySlice1 = slice(y, 0, getNumThresholds() - 1); // y[1:] - Operand ySlice2 = tf.slice(y, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand ySlice2 = slice(y, 1, -1); + - Operand heights = null; + Operand heights; switch (getSummationMethod()) { case INTERPOLATION: heights = tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type())); @@ -903,17 +953,16 @@ public Operand result() { case MAJORING: heights = tf.math.maximum(ySlice1, ySlice2); break; + default: + throw new IllegalArgumentException("Unexpected AUCSummationMethod value: " + getSummationMethod()); } if (isMultiLabel()) { Operand riemannTerms = tf.math.mul( tf.math.sub( - tf.slice( - x, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})), - tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))), + 
slice(x, 0, getNumThresholds() - 1), + slice(x, 1, -1)), heights); Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0)); @@ -928,9 +977,8 @@ public Operand result() { } } else { - Operand slice1 = - tf.slice(x, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); - Operand slice2 = tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand slice1 = slice(x,0, getNumThresholds() - 1); + Operand slice2 = slice(x, 1, -1); Operand sub = tf.math.sub(slice1, slice2); Operand operand = tf.math.mul(sub, heights); return tf.reduceSum(operand, allAxes(tf, operand)); From 9f4044a00344d306025965db2e33edd08534a004 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 08:09:48 -0400 Subject: [PATCH 32/97] Fix Javdoc, remove spurious y_pred. --- .../org/tensorflow/framework/metrics/CategoricalAccuracy.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index c0635746d4d..c3780cc6de2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,7 +27,7 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

You can provide logits of classes as predictionsy_pred, since argmax + *

You can provide logits of classes as predictions, since argmax * of logits and probabilities are same. * *

This metric creates two local variables, total and count that are From 5e50e9547accda97d7cc779d9cedd61b4b9ac606 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 08:10:41 -0400 Subject: [PATCH 33/97] remove spurious cast --- .../org/tensorflow/framework/metrics/impl/MetricsHelper.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 07a630560fe..a754c93be46 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -604,8 +604,7 @@ private static Operand weightedAssignAdd( Operand labelAndPred = cast(tf, tf.math.logicalAnd(labels, predictions), type); if (weights != null) { - Operand lWeights = cast(tf, weights, type); - labelAndPred = tf.math.mul(labelAndPred, lWeights); + labelAndPred = tf.math.mul(labelAndPred, weights); } Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); Operand assignAdd; From 81fc9fd718d92b003cca32f2daad2752ec165620 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 08:11:20 -0400 Subject: [PATCH 34/97] correct comments for enums --- .../framework/metrics/impl/ConfusionMatrixEnum.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java index 281aa2072d0..bf3ade53d73 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -18,9 +18,9 @@ public enum ConfusionMatrixEnum { /** These 
are cases in which the prediction is true, and reality is true. */ TRUE_POSITIVES("tp"), - /** These are cases in which the prediction is false, and reality is true. */ - FALSE_POSITIVES("fp"), /** These are cases in which the prediction is true, and reality is false. */ + FALSE_POSITIVES("fp"), + /** These are cases in which the prediction is false, and reality is true. */ TRUE_NEGATIVES("tn"), /** These are cases in which the prediction is false, and reality is false. */ FALSE_NEGATIVES("fn"); From 743f416991135fea6b964a32b10dc8e6178d0573 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 16:21:24 -0400 Subject: [PATCH 35/97] Fix the documentation on TP, FP, TN, and FN --- .../framework/metrics/impl/ConfusionMatrixEnum.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java index bf3ade53d73..caa5f203f9f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -20,9 +20,9 @@ public enum ConfusionMatrixEnum { TRUE_POSITIVES("tp"), /** These are cases in which the prediction is true, and reality is false. */ FALSE_POSITIVES("fp"), - /** These are cases in which the prediction is false, and reality is true. */ - TRUE_NEGATIVES("tn"), /** These are cases in which the prediction is false, and reality is false. */ + TRUE_NEGATIVES("tn"), + /** These are cases in which the prediction is false, and reality is true. */ FALSE_NEGATIVES("fn"); private final String abbrev; From 5ed35f3326d252b0877d65f7e0162f120fcc2ae9 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 16:50:09 -0400 Subject: [PATCH 36/97] Added code comments to fitlerTopK. 
--- .../org/tensorflow/framework/metrics/impl/MetricsHelper.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 6d35d1a71c4..d34a7b25111 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -664,7 +664,9 @@ private static Operand weightedAssignAdd( private static Operand filterTopK(Ops tf, Operand x, int topK) { Class type = x.type(); Shape xShape = x.shape(); + // top has the same rank as x; the last dimension becomes indices of the topK features. TopK top = tf.nn.topK(x, tf.constant(topK), TopK.sorted(false)); + // oneHot has an additional dimension: the one-hot representation of each topK index. OneHot oneHot = tf.oneHot( top.indices(), @@ -672,6 +674,7 @@ private static Operand filterTopK(Ops tf, Operand x, i tf.constant(1), tf.constant(0), OneHot.axis(-1L)); + // Sum the one-hot representations along the last dimension of x. 
Operand topKMask = cast(tf, tf.reduceSum(oneHot, tf.constant(-2)), type); // x * top_k_mask + NEG_INF * (1 - top_k_mask) From 929ce7e8e60755d44212657b231b689918b99623 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:33:51 -0400 Subject: [PATCH 37/97] JavaDoc fixes and code cleanup and add code comments --- .../org/tensorflow/framework/metrics/AUC.java | 18 ++++++++++-------- .../tensorflow/framework/metrics/MeanIoU.java | 14 +++++++++++--- .../framework/metrics/MeanTensor.java | 15 ++++++++++++--- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index b09154b46da..01200fd39b4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -792,10 +792,10 @@ private Operand positive(Operand input) { } /** - * Gets an operand that determines whether the input consists of each value is greater than 0. + * Gets the truth value of whether {@code input > 0}, element-wise. * * @param input the input - * @return an operand that determines whether the input consists of all values greater than 0. + * @return the truth value of whether {@code input > 0}, element-wise. 
*/ private Operand isPositive(Operand input) { return getTF().math.greater(input, cast(getTF(), getTF().constant(0), input.type())); @@ -864,23 +864,25 @@ private Operand interpolatePRAuc() { Operand dTP = tf.math.sub(tp0, tp1); Operand p = tf.math.add(truePositives, falsePositives); + Operand p0= slice(p, 0, getNumThresholds() - 1); + Operand p1= slice(p, 1, -1); - Operand dP = tf.math.sub(slice(p, 0, getNumThresholds() - 1), slice(p, 1, -1)); + Operand dP = tf.math.sub(p0,p1); Operand precisionSlope = tf.math.divNoNan(dTP, positive(dP)); Operand intercept = - tf.math.sub(slice(truePositives, 1, -1), tf.math.mul(precisionSlope, slice(p, 1, -1))); + tf.math.sub(tp1, tf.math.mul(precisionSlope, p1)); Operand safePRatio = tf.select( tf.math.logicalAnd( - isPositive(slice(p, 0, getNumThresholds() - 1)), isPositive(slice(p, 1, -1))), + isPositive(p0), isPositive(p1)), tf.math.divNoNan( - slice(p, 0, getNumThresholds() - 1), - positive(slice(p, 1, -1))), - tf.onesLike(slice(p, 1, -1))); + p0, + positive(p1)), + tf.onesLike(p1)); Operand fn1 = slice(falseNegatives, 1, -1); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 19b13ed391c..3cd3fd7c0ee 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -124,13 +124,17 @@ public List updateStateList( Operand sampleWeights) { Operand tLabels = cast(getTF(), labels, type); - if (tLabels.shape().numDimensions() > 1) tLabels = getTF().shape.flatten(tLabels); + if (tLabels.shape().numDimensions() > 1) { + tLabels = getTF().shape.flatten(tLabels); + } Operand tPredictions = cast(getTF(), predictions, type); - if (tPredictions.shape().numDimensions() > 1) + if (tPredictions.shape().numDimensions() > 1) { tPredictions = getTF().shape.flatten(tPredictions); + } Operand 
tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null; - if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) + if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { tSampleWeights = getTF().shape.flatten(tSampleWeights); + } Operand currentCM = MetricsHelper.confusionMatrix( @@ -149,6 +153,10 @@ public Operand result() { totalConfusionMatrix, tf.constant(0), cast(tf, tf.constant(0), totalConfusionMatrix.type())); + // for each class, the total predictions + total labels - true positives + // Observe that total predictions = tp + fp + // total labels = tp + fn + // So this is 2*tp + fp + fn - tp = tp + fp + fn Operand denominator = tf.math.add(sumOverRow, tf.math.sub(sumOverCol, truePositives)); Operand numValidEntries = tf.reduceSum( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index 3d6d8194aac..f01cb47b256 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -103,14 +103,19 @@ private boolean init(Shape shape) { } } - /** {@inheritDoc} */ + /** + * Accumulates statistics for computing the element-wise mean. + * + * @param values Per-example value. Input values must always have the same shape for all + * invocations of updateStateList. + * @param sampleWeights Optional weighting of each example. Defaults to 1 if null. + */ @Override public List updateStateList( Operand values, Operand sampleWeights) { Ops tf = getTF(); Operand tValues = cast(tf, values, type); - Operand tSampleWeights = null; - if (sampleWeights != null) tSampleWeights = cast(tf, sampleWeights, type); + Operand tSampleWeights = sampleWeights == null ? 
null : cast(tf, sampleWeights, type); boolean needsInitialization = init(values.shape()); @@ -123,13 +128,17 @@ public List updateStateList( Operand numValues = tf.onesLike(tValues); if (tSampleWeights != null) { + //Update dimensions of weights to match with values if possible. LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights); tValues = tuple.getTarget(); tSampleWeights = tuple.getSampleWeights(); try { + // Broadcast weights if possible. tSampleWeights = WeightsBroadcastOps.broadcastWeights(tf, tSampleWeights, tValues); } catch (IllegalArgumentException ex) { + // sampleWeights cannot be broadcast to values + // Reduce values to same ndim as weight array int ndim = values.shape().numDimensions(); int weightNdim = tSampleWeights.asOutput().shape().numDimensions(); int[] range = new int[ndim - weightNdim]; From 1b36693a4f3a578f0d5c389737b6e96ae1061b49 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:34:59 -0400 Subject: [PATCH 38/97] JavaDoc fixes and code cleanup and add code comments Remose shape flatten in updateStateList --- .../framework/metrics/MeanRelativeError.java | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index 4c48c0f88a7..b8cec2150b7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -33,8 +33,8 @@ * ultimately returned as mean relative error: an idempotent operation that simply divides total by * count. * - *

If sampleWeight is null, weights default to 1. Use sample_weight of - * 0 to mask * values. + *

If {@code sampleWeight} is null, weights default to 1. Use {@code sampleWeight} of + * 0 to mask values. * * @param The data type for the metric result */ @@ -124,27 +124,29 @@ protected MeanRelativeError( this.normalizer = normalizer; } - /** {@inheritDoc} */ + /** + * Accumulates metric statistics. + * + * @param labels The ground truth values. + * @param predictions The predicted values. Must be the same shape as the normalizer. + * @param sampleWeights Optional weighting of each example. A null value defaults to 1. Can be an {@code Operand} + * whose rank is either 0, or the same rank as {@code labels}, and must be broadcastable to + * {@code labels}. + * @return a List of Operations to update the metric state + */ @Override public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { Operand tLabels = cast(getTF(), labels, getResultType()); - if (tLabels.shape().numDimensions() > 1) tLabels = getTF().shape.flatten(tLabels); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - if (tPredictions.shape().numDimensions() > 1) - tPredictions = getTF().shape.flatten(tPredictions); + Operand tSampleWeights = + sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); tPredictions = tuple.getTarget(); tLabels = tuple.getLabels(); - Operand tSampleWeights = - sampleWeights != null ? 
cast(getTF(), sampleWeights, getResultType()) : null; - if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { - tSampleWeights = getTF().shape.flatten(tSampleWeights); - } tuple = LossesHelper.removeSqueezableDimensions(getTF(), normalizer, tPredictions); normalizer = tuple.getLabels(); From d301b01be4992f7506166b268a00458a0b2528af Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:35:46 -0400 Subject: [PATCH 39/97] Fix code in sparseTopKCategoricalAccuracy to reshape to proper dimensions --- .../java/org/tensorflow/framework/metrics/Metrics.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index bcf2ea4c880..3d4c262491f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -23,7 +23,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -/** Helper class with built-in metrics functions. */ +/** Static methods for computing metrics. 
*/ public class Metrics { /** @@ -84,6 +84,7 @@ public static Operand sparseTopKCatego Ops tf, Operand labels, Operand predictions, int k) { Operand tLabels = cast(tf, labels, predictions.type()); + // Flatten predictions to (batch_size, num_samples) and labels to (num_samples,) int predictionsRank = predictions.shape().numDimensions(); int labelsRank = tLabels.shape().numDimensions(); @@ -91,10 +92,13 @@ public static Operand sparseTopKCatego Operand castPredictions = cast(tf, predictions, TFloat32.class); if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { if (predictionsRank > 2) { - castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); + //y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) + castPredictions = tf.reshape(castPredictions, + tf.constant(castPredictions.shape().takeLast(1).prepend(Shape.UNKNOWN_SIZE))); } if (labelsRank > 1) { - tLabels = tf.shape.flatten(tLabels); + //y_true = array_ops.reshape(y_true, [-1]) + tLabels = tf.reshape(tLabels, tf.constant(Shape.of(Shape.UNKNOWN_SIZE))); } } return cast( From 020eb9c51632f8669c821fde829c54adf3e60dd7 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:36:14 -0400 Subject: [PATCH 40/97] Fix JavaDoc --- .../metrics/SparseTopKCategoricalAccuracy.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java index 7db290530cd..0fd600b4a0f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -22,7 +22,10 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -/** @param The data type for the metric result */ +/** + * Computes how often 
integer targets are in the top `K` predictions. + * @param The data type for the metric result + * */ public class SparseTopKCategoricalAccuracy extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_K = 5; @@ -30,8 +33,7 @@ public class SparseTopKCategoricalAccuracy extends MeanMetric private final int k; /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of - * top elements to look at for computing accuracy. + * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top elements. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. @@ -44,7 +46,7 @@ public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class ty } /** - * Creates a TopKCategoricalAccuracy metric + * Creates a SparseTopKCategoricalAccuracy metric. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. From 3ae642dbb0e766eab3476a5dd39fa94393574555 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:37:48 -0400 Subject: [PATCH 41/97] Fix JavaDoc --- .../framework/metrics/impl/MetricsHelper.java | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index d34a7b25111..c06616a6324 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -764,10 +764,37 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( /** * Computes the confusion matrix from predictions and labels. * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape {@code [n, n]}, where {@code n} is the + * number of valid labels for a given classification task. Both prediction and labels must be 1-D + * arrays of the same shape in order for this function to work. + * + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum + * value in either predictions or labels. Class labels are expected to start at 0. For example, if + * {@code numClasses}` is 3, then the possible labels would be {@code [0, 1, 2]}. + * + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to + * the total value of the confusion matrix cell. + * + *

For example: + * + *

+   *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
+   *          [[0 0 0 0 0]
+   *           [0 0 1 0 0]
+   *           [0 0 1 0 0]
+   *           [0 0 0 0 0]
+   *           [0 0 0 0 1]]
+   * 
+ * + * Note that the possible labels are assumed to be {@copde [0, 1, 2, 3,4]}, resulting in a 5x5 + * confusion matrix. + * * @param tf the TensorFlow Ops - * @param labels 1-D `Tensor` of real labels for the classification task. - * @param predictions 1-D `Tensor` of predictions for a given classification. - * @param numClasses The possible number of labels the classification task can have. + * @param labels 1-D {@code Operand} of real labels for the classification task. + * @param predictions 1-D {@code Operand} of predictions for a given classification. + * @param numClasses The possible number of labels the classification task can have. If this value + * is not provided, it will be calculated using both predictions and labels array. * @param weights optional weights to be applied to the confusion matrix * @param type Data type of the confusion matrix. * @param the type of Operands @@ -778,6 +805,7 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * not have compatible shapes, or if weights is notnull and its * shape is not compatible with predictions. */ + // TODO should this be moved to FramnworkOps under math. public static Operand confusionMatrix( Ops tf, Operand labels, From 8b881b4a6e030aea1ab785dfd4c9a2a28b82d2b1 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 19:22:10 -0400 Subject: [PATCH 42/97] Fixed Javadoc, mainly to add shape requirements. 
Reformat code --- .../org/tensorflow/framework/metrics/AUC.java | 75 +++++++++---------- .../framework/metrics/AUCSummationMethod.java | 6 +- .../framework/metrics/Accuracy.java | 11 ++- .../framework/metrics/BinaryAccuracy.java | 8 +- .../framework/metrics/BinaryCrossentropy.java | 9 ++- .../metrics/CategoricalAccuracy.java | 16 +++- .../metrics/CategoricalCrossentropy.java | 8 +- .../framework/metrics/CategoricalHinge.java | 8 +- .../framework/metrics/CosineSimilarity.java | 18 ++++- .../tensorflow/framework/metrics/Hinge.java | 8 +- .../framework/metrics/KLDivergence.java | 8 +- .../framework/metrics/LogCoshError.java | 8 +- .../framework/metrics/MeanAbsoluteError.java | 8 +- .../metrics/MeanAbsolutePercentageError.java | 8 +- .../tensorflow/framework/metrics/MeanIoU.java | 15 +++- .../framework/metrics/MeanRelativeError.java | 12 +-- .../framework/metrics/MeanSquaredError.java | 21 +++++- .../metrics/MeanSquaredLogarithmicError.java | 8 +- .../framework/metrics/MeanTensor.java | 5 +- .../tensorflow/framework/metrics/Metrics.java | 8 +- .../tensorflow/framework/metrics/Poisson.java | 8 +- .../framework/metrics/Precision.java | 18 ++++- .../framework/metrics/PrecisionAtRecall.java | 1 + .../tensorflow/framework/metrics/Recall.java | 18 ++++- .../metrics/RootMeanSquaredError.java | 10 ++- .../metrics/SparseCategoricalAccuracy.java | 9 ++- .../SparseCategoricalCrossentropy.java | 13 +++- .../SparseTopKCategoricalAccuracy.java | 14 +++- .../framework/metrics/SquaredHinge.java | 10 ++- .../metrics/TopKCategoricalAccuracy.java | 8 +- .../impl/ConfusionMatrixConditionCount.java | 11 ++- .../framework/metrics/impl/MetricsHelper.java | 50 ++++++------- .../framework/metrics/impl/Reduce.java | 12 ++- .../impl/SensitivitySpecificityBase.java | 21 +++++- 34 files changed, 357 insertions(+), 114 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java 
index 01200fd39b4..3dbc6f22cec 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -715,11 +715,11 @@ private Map> build(Shape shape) { * * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be - * cast to T. If {@link #multiLabel} or if {@link #labelWeights} != null + * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} != null * , then Cx must be a single dimension. * @param predictions the predictions shape (N, Cx, P1?). Will be cast to T. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to - * T. + * {@code }. * @return a List of Operations to update the metric state */ @Override @@ -795,7 +795,7 @@ private Operand positive(Operand input) { * Gets the truth value of whether {@code input > 0}, element-wise. * * @param input the input - * @return the truth value of whether {@code input > 0}, element-wise. + * @return the truth value of whether {@code input > 0}, element-wise. */ private Operand isPositive(Operand input) { return getTF().math.greater(input, cast(getTF(), getTF().constant(0), input.type())); @@ -818,41 +818,52 @@ private Operand slice(Operand input, int begin, int size) { * Interpolation formula inspired by section 4 of Davis & Goadrich 2006. * *

Note here we derive & use a closed formula not present in the paper as follows: + * *

    *     Precision = TP / (TP + FP) = TP / P
    * 
- *

Modeling all of TP (true positive), FP (false positive) and their sum - * P = TP + FP (predicted positive) as varying linearly within each interval - * [A, B] between successive thresholds, we get

+ * + *

Modeling all of TP (true positive), FP (false positive) and their sum P = TP + FP (predicted + * positive) as varying linearly within each interval [A, B] between successive thresholds, we get + * *

    *     Precision slope = dTP / dP
    *                     = (TP_B - TP_A) / (P_B - P_A)
    *                     = (TP - TP_A) / (P - P_A)
    *     Precision = (TP_A + slope * (P - P_A)) / P
    * 
+ * *

The area within the interval is (slope / total_pos_weight) times + * *

    *       int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
    *       int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
    * 
- * where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + * + * where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + * *
    *       int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
    * 
- * Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + * + * Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + * *
    *       slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
    * 
- * where dTP == TP_B - TP_A. - * Note that when P_A == 0 the above calculation simplifies into + * + * where dTP == TP_B - TP_A. Note that when P_A == 0 the above calculation simplifies into + * *
    *       int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
    * 
- * which is really equivalent to imputing constant precision throughout the - * first bucket having >0 true positives. + * + * which is really equivalent to imputing constant precision throughout the first bucket having >0 + * true positives. * * @return an approximation of the area under the P-R curve. - * @see The Relationship Between Precision-Recall and ROC Curves - Davis & Goadrich 2006 + * @see The Relationship Between + * Precision-Recall and ROC Curves - Davis & Goadrich 2006 */ private Operand interpolatePRAuc() { // truePositives[:self.numThresholds - 1] @@ -864,24 +875,19 @@ private Operand interpolatePRAuc() { Operand dTP = tf.math.sub(tp0, tp1); Operand p = tf.math.add(truePositives, falsePositives); - Operand p0= slice(p, 0, getNumThresholds() - 1); - Operand p1= slice(p, 1, -1); + Operand p0 = slice(p, 0, getNumThresholds() - 1); + Operand p1 = slice(p, 1, -1); - Operand dP = tf.math.sub(p0,p1); + Operand dP = tf.math.sub(p0, p1); - Operand precisionSlope = - tf.math.divNoNan(dTP, positive(dP)); + Operand precisionSlope = tf.math.divNoNan(dTP, positive(dP)); - Operand intercept = - tf.math.sub(tp1, tf.math.mul(precisionSlope, p1)); + Operand intercept = tf.math.sub(tp1, tf.math.mul(precisionSlope, p1)); Operand safePRatio = tf.select( - tf.math.logicalAnd( - isPositive(p0), isPositive(p1)), - tf.math.divNoNan( - p0, - positive(p1)), + tf.math.logicalAnd(isPositive(p0), isPositive(p1)), + tf.math.divNoNan(p0, positive(p1)), tf.onesLike(p1)); Operand fn1 = slice(falseNegatives, 1, -1); @@ -890,15 +896,12 @@ private Operand interpolatePRAuc() { tf.math.mul( precisionSlope, tf.math.add(dTP, tf.math.mul(intercept, tf.math.log(safePRatio)))); - Operand prAucIncrement = - tf.math.divNoNan( - aucTotalPos, - positive(tf.math.add(tp1, fn1))); + Operand prAucIncrement = tf.math.divNoNan(aucTotalPos, positive(tf.math.add(tp1, fn1))); if (isMultiLabel()) { Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0)); if (getLabelWeights() == null) { - 
//Evenly weighted average of the label AUCs. + // Evenly weighted average of the label AUCs. return MetricsHelper.mean(tf, byLabelAuc); } else { // Weighted average of the label AUCs. @@ -943,7 +946,6 @@ public Operand result() { // y[1:] Operand ySlice2 = slice(y, 1, -1); - Operand heights; switch (getSummationMethod()) { case INTERPOLATION: @@ -956,16 +958,13 @@ public Operand result() { heights = tf.math.maximum(ySlice1, ySlice2); break; default: - throw new IllegalArgumentException("Unexpected AUCSummationMethod value: " + getSummationMethod()); + throw new IllegalArgumentException( + "Unexpected AUCSummationMethod value: " + getSummationMethod()); } if (isMultiLabel()) { Operand riemannTerms = - tf.math.mul( - tf.math.sub( - slice(x, 0, getNumThresholds() - 1), - slice(x, 1, -1)), - heights); + tf.math.mul(tf.math.sub(slice(x, 0, getNumThresholds() - 1), slice(x, 1, -1)), heights); Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0)); if (getLabelWeights() == null) { @@ -979,7 +978,7 @@ public Operand result() { } } else { - Operand slice1 = slice(x,0, getNumThresholds() - 1); + Operand slice1 = slice(x, 0, getNumThresholds() - 1); Operand slice2 = slice(x, 1, -1); Operand sub = tf.math.sub(slice1, slice2); Operand operand = tf.math.mul(sub, heights); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java index 3887f687eea..735d97ecf09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java @@ -21,8 +21,10 @@ * @see Riemann summation method */ public enum AUCSummationMethod { - /** Apply mid-point summation scheme for {@link AUCCurve#ROC}. 
For {@link AUCCurve#PR}, interpolates (true/false) positives but not the ratio that - * is precision */ + /** + * Apply mid-point summation scheme for {@link AUCCurve#ROC}. For {@link AUCCurve#PR}, + * interpolates (true/false) positives but not the ratio that is precision + */ INTERPOLATION, /** Apply right summation for increasing intervals and left summation for decreasing intervals */ MAJORING, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index 9548fb42c65..30787a9889b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -65,7 +66,15 @@ public Accuracy(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates how often predictions equals labels. {@code labels} and {@code predictions} must + * have compatible shapes, see {@link Shape @isCompatibleWith}. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @throws IllegalArgumentException if predictions and labels shapes are not compatible. 
+ * @return the loss + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index d2a414fdeb7..4f9a267d633 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -85,7 +85,13 @@ public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates how often predictions match binary labels. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Binary accuracy values. shape = {@code [batch_size, d0, .. dN-1]} + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 48ee244eafb..57a6f75375d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -60,7 +60,14 @@ public BinaryCrossentropy( this.labelSmoothing = labelSmoothing; } - /** {@inheritDoc} */ + /** + * Computes the binary crossentropy loss between labels and predictions. + * + * @param labels the truth values or labels, has the same shape as predictions and shape = {@code + * [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Binary crossentropy loss value. shape = {@code [batch_size, d0, .. dN-1]}. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index c3780cc6de2..55c3dc800e1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,8 +27,8 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

You can provide logits of classes as predictions, since argmax - * of logits and probabilities are same. + *

You can provide logits of classes as predictions, since argmax of + * logits and probabilities are same. * *

This metric creates two local variables, total and count that are * used to compute the frequency with which predictions matches labels. @@ -73,7 +73,17 @@ public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { super.setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the categorical crossentropy loss. + * + *

{@code predictions} and {@code labels} should be passed in as vectors of probabilities, + * rather than as labels. If necessary, use {@line Ops#oneHot} to expand {@code labels} as a + * vector. + * + * @param labels One-hot ground truth values. + * @param predictions tThe prediction values. + * @return Categorical accuracy values. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index b22e5415f79..a7e85ce5b02 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -99,7 +99,13 @@ public CategoricalCrossentropy( this.axis = axis; } - /** {@inheritDoc} */ + /** + * Computes the crossentropy loss between the labels and predictions. + * + * @param labels the truth values or labels, of one-hot true targets, same shape as predictions + * @param predictions the predictions + * @return Categorical crossentropy loss value. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 4266cc487c0..1f6d0fd002c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -45,7 +45,13 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the categorical hinge metric between {@code labels} and @{code predictions}. 
+ * + * @param labels the truth values or labels, labels values are expected to be 0 or 1. + * @param predictions the predictions + * @return Categorical hinge loss values. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 840f255c5ab..230286a738f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -26,7 +26,17 @@ /** * A metric that computes the cosine similarity metric between labels and predictions. * + *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 + * indicates orthogonality and values closer to -1 indicate greater similarity. The values closer to + * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + * try to maximize the proximity between predictions and targets. If either labels and predictions + * is a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and + * targets. + * + *

{@code loss = -sum(l2_norm(y_true) * l2_norm(y_pred))}
+ * * @param The data type for the metric result. + * @see Cosine Similarity */ public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { @@ -76,7 +86,13 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the cosine similarity loss between labels and predictions. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @return the cosine similarity loss + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 46ccd2859ff..a2d110867b8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -44,7 +44,13 @@ public Hinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the hinge loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return the hinge loss between labels and predictions. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 9ffcd6189f1..155a891ccc2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -45,7 +45,13 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes Kullback-Leibler divergence metric between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return the loss with shape {@code [batch_size, d0, .. dN-1]} + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 59e24f57110..786847d4b32 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -45,7 +45,13 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates the Logarithm of the hyperbolic cosine of the prediction error. + * + * @param labels Ground truth values, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Logcosh error values, shape = {@code [batch_size, d0, .. dN-1]}. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 1cc6d0b6f99..b38d0a809e1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -45,7 +45,13 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean absolute error loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean absolute error values, shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 8c6720b58f6..22bcd0ab0eb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -45,7 +45,13 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean absolute percentage error loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean absolute percentage error values, shape = {@code [batch_size, d0, .. 
dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 3cd3fd7c0ee..70c2e6db8f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -116,7 +116,15 @@ public Assign getInitializer() { return initializer; } - /** {@inheritDoc} */ + /** + * Accumulates the confusion matrix statistics. + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either + * 0, or the same rank as labels, and must be broadcastable to labels. + * @return the Operands that updates totalConfusionMatrix variable + */ @Override public List updateStateList( Operand labels, @@ -130,12 +138,13 @@ public List updateStateList( Operand tPredictions = cast(getTF(), predictions, type); if (tPredictions.shape().numDimensions() > 1) { tPredictions = getTF().shape.flatten(tPredictions); - } + } Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null; if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { tSampleWeights = getTF().shape.flatten(tSampleWeights); - } + } + // Accumulate the prediction to current confusion matrix. 
Operand currentCM = MetricsHelper.confusionMatrix( getTF(), tLabels, tPredictions, getTF().constant(numClasses), tSampleWeights, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index b8cec2150b7..ac25183c0e5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -33,8 +33,8 @@ * ultimately returned as mean relative error: an idempotent operation that simply divides total by * count. * - *

If {@code sampleWeight} is null, weights default to 1. Use {@code sampleWeight} of - * 0 to mask values. + *

If {@code sampleWeight} is null, weights default to 1. Use {@code sampleWeight} + * of 0 to mask values. * * @param The data type for the metric result */ @@ -129,9 +129,9 @@ protected MeanRelativeError( * * @param labels The ground truth values. * @param predictions The predicted values. Must be the same shape as the normalizer. - * @param sampleWeights Optional weighting of each example. A null value defaults to 1. Can be an {@code Operand} - * whose rank is either 0, or the same rank as {@code labels}, and must be broadcastable to - * {@code labels}. + * @param sampleWeights Optional weighting of each example. A null value defaults to 1. Can be an + * {@code Operand} whose rank is either 0, or the same rank as {@code labels}, and must be + * broadcastable to {@code labels}. * @return a List of Operations to update the metric state */ @Override @@ -142,7 +142,7 @@ public List updateStateList( Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); Operand tSampleWeights = - sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); tPredictions = tuple.getTarget(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 3c4c79d39ba..fd8be29875e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -26,6 +26,19 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * + *

The {@code MeanSquaredError} class creates two local variables, {@code total} and {@code + * count} that are used to compute the mean squared error. This average is weighted by {@code + * weights}, and it is ultimately returned as the mean squared error: an idempotent operation that + * simply divides {@code total} by {@code count}. + * + *

For estimation of the metric over a stream of data, the function creates an update operation + * that updates these variables. Internally, a squared error operation computes the element-wise + * square of the difference between {@code predictions} and {@code labels}. Then the update + * operation increments {@code total} with the reduced sum of the product of {@code weights} and the + * squared error, and it increments {@code count} with the reduced sum of {@code weights}. + * + *

If {@code weights} is null, weights default to 1. Use weights of 0 to mask values. + * * @param The data type for the metric result. */ public class MeanSquaredError extends MeanMetricWrapper @@ -45,7 +58,13 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean squared error between the labels and predictions. + * + * @param labels the truth values or labels. Must be the same shape as predictions. + * @param predictions the predictions + * @return Computes the mean squared error between the labels and predictions. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index d525bb76648..4728cbab12f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -45,7 +45,13 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean squared logarithmic error between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean squared logarithmic error values, shape = {@code [batch_size, d0, .. dN-1]}. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index f01cb47b256..d88d7a4c1b4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -109,6 +109,8 @@ private boolean init(Shape shape) { * @param values Per-example value. Input values must always have the same shape for all * invocations of updateStateList. * @param sampleWeights Optional weighting of each example. Defaults to 1 if null. + * @throws IllegalArgumentException if the shape of {@code values} in each subsequent call is not + * the same shape as {@code values} set during the first call */ @Override public List updateStateList( @@ -117,6 +119,7 @@ public List updateStateList( Operand tValues = cast(tf, values, type); Operand tSampleWeights = sampleWeights == null ? null : cast(tf, sampleWeights, type); + // update the shape if it is the first call. boolean needsInitialization = init(values.shape()); if (!this.shape.equals(values.shape())) { @@ -128,7 +131,7 @@ public List updateStateList( Operand numValues = tf.onesLike(tValues); if (tSampleWeights != null) { - //Update dimensions of weights to match with values if possible. + // Update dimensions of weights to match with values if possible. 
LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights); tValues = tuple.getTarget(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 3d4c262491f..a33750ac3f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -92,12 +92,14 @@ public static Operand sparseTopKCatego Operand castPredictions = cast(tf, predictions, TFloat32.class); if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { if (predictionsRank > 2) { - //y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) - castPredictions = tf.reshape(castPredictions, + // y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) + castPredictions = + tf.reshape( + castPredictions, tf.constant(castPredictions.shape().takeLast(1).prepend(Shape.UNKNOWN_SIZE))); } if (labelsRank > 1) { - //y_true = array_ops.reshape(y_true, [-1]) + // y_true = array_ops.reshape(y_true, [-1]) tLabels = tf.reshape(tLabels, tf.constant(Shape.of(Shape.UNKNOWN_SIZE))); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 422fd4808ff..2e4bde8ec55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -44,7 +44,13 @@ public Poisson(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the Poisson loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. 
+ * @return Poisson loss value, shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index bd536f16b29..5784cf46385 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -22,9 +22,14 @@ import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -290,7 +295,16 @@ private void init() { } } - /** {@inheritDoc} */ + /** + * Accumulates true positive and false positive statistics. + * + * @param labels the labels The ground truth values, with the same dimensions as predictions. Will + * be cast to {@link TBool}. + * @param predictions the predictions, each element must be in the range {@code [0, 1]}. + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or * + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. 
+ */ @Override @SuppressWarnings("unchecked") public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 299c649279f..4205f761e4b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -110,6 +110,7 @@ public PrecisionAtRecall( this.recall = recall; } + /** {@inheritDoc} */ @Override public Operand result() { Ops tf = getTF(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 54e9de0d9cf..ca5968d4f9d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -22,9 +22,14 @@ import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -321,7 +326,16 @@ public Op resetStates() { return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); } - /** {@inheritDoc} */ + /** + * Accumulates true positive and false negative statistics. + * + * @param labels the labels The ground truth values, with the same dimensions as predictions. Will + * be cast to {@link TBool}. + * @param predictions the predictions, each element must be in the range {@code [0, 1]}. + * @param sampleWeights Optional weighting of each example. Defaults to 1. 
Rank is either 0, or * + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ @Override @SuppressWarnings("unchecked") public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 9b4401964d7..721b95487c7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -60,7 +60,15 @@ public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); } - /** {@inheritDoc} */ + /** + * Accumulates root mean squared error statistics. + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or * + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ @Override public List updateStateList( Operand labels, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 7bfa7fd6ee9..6dfdab48578 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -94,7 +94,13 @@ public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) super.setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates how often predictions matches integer labels. + * + * @param labels Integer ground truth values. 
+ * @param predictions the predictions + * @return Sparse categorical accuracy values. + */ @Override public Operand call( Operand labels, Operand predictions) { @@ -106,6 +112,7 @@ public Operand call( long predictionsRank = predShape.numDimensions(); long labelsRank = labelsShape.numDimensions(); + // If the shape of labels is (num_samples, 1), squeeze to (num_samples,) if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE && labelsShape.size((int) labelsRank - 1) == 1) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 9949f0c6b60..04555d85b66 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -25,7 +25,10 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and - * predicted labels. \ + * predicted labels. + * + *

You can provide logits of classes as predictions, since argmax of logits and probabilities are + * same. * * @param The data type for the metric result. */ @@ -55,7 +58,13 @@ public SparseCategoricalCrossentropy( this.axis = axis; } - /** {@inheritDoc} */ + /** + * Calculates how often predictions matches integer labels. + * + * @param labels Integer ground truth values. + * @param predictions the predictions + * @return Sparse categorical accuracy values. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java index 0fd600b4a0f..29dc91298d3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -24,8 +24,9 @@ /** * Computes how often integer targets are in the top `K` predictions. + * * @param The data type for the metric result - * */ + */ public class SparseTopKCategoricalAccuracy extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_K = 5; @@ -33,7 +34,8 @@ public class SparseTopKCategoricalAccuracy extends MeanMetric private final int k; /** - * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top elements. + * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top + * elements. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. @@ -61,7 +63,13 @@ public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Clas setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes how often integer targets are in the top {@code K} predictions. 
+ * + * @param labels the truth values or labels + * @param predictions the predictions + * @return Sparse top K categorical accuracy value. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 19b3b1d0ac4..e2ff208b8f5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -44,7 +44,15 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the squared hinge loss between labels and predictions. + * + * @param labels The ground truth values. {@code labels} values are expected to be -1 or 1. If + * binary (0 or 1) labels are provided we will convert them to -1 or 1. shape = {@code + * [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Squared hinge loss values. shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index ad78e48bc34..9c8d6403a6b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -63,7 +63,13 @@ public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class t setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes how often targets are in the top {@code K} predictions. 
+ * + * @param labels the truth values or labels + * @param predictions the predictions + * @return Top K categorical accuracy value. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index 31e88b6bb31..63ea35df7f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -114,6 +114,7 @@ public ConfusionMatrixConditionCount( init(); } + /** Initialize the metric */ private void init() { Shape variableShape = Shape.of(this.thresholds.length); @@ -134,7 +135,15 @@ public Assign getInitializer() { return initializer; } - /** {@inheritDoc} */ + /** + * Accumulates the metric statistics. + * + * @param labels The ground truth values. + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. 
+ */ @Override public List updateStateList( Operand labels, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index c06616a6324..f36aaa34d8f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -22,7 +22,6 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; - import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.OneHot; import org.tensorflow.op.core.Rank; @@ -30,7 +29,6 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; import org.tensorflow.op.nn.TopK; - import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -269,7 +267,8 @@ public static List assertShapes( Operand size = dict.get(s); if (size == null) { // save size for later checks - size = tf.shape.size( symbol.getOperand(), tf.constant(ll.get()), TInt64.class); + size = + tf.shape.size(symbol.getOperand(), tf.constant(ll.get()), TInt64.class); dict.put(s, size); } Op assertion = @@ -280,7 +279,7 @@ public static List assertShapes( symbol.getOperand(), tf.constant(ll.getAndIncrement()), TInt64.class), - size), + size), Collections.singletonList(tf.constant(message))); updateOperations.add(assertion); }); @@ -305,47 +304,48 @@ public static List assertShapes( * will repeat the same for every threshold. * *

For estimation of these metrics over a stream of data, the function creates an `update_op` - * operation that updates the given variables.

+ * operation that updates the given variables. * *

labels, predictions, and sampleWeight tensors are - * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. - * sampleWeight is then broadcast to the shape of predictions.

+ * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. + * sampleWeight is then broadcast to the shape of predictions. * * @param tf the TensorFlow Ops * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If multiLabel, then the - * variable shapes are (T, D), where T is the number of thresholds and D is the number of - * classes (after slicing by classIndex, if provided). - * If multiLabels, then the variable shapes are (T). + * corresponding variables to update as values. If multiLabel, then the variable + * shapes are (T, D), where T is the number of thresholds and D is the number of classes + * (after slicing by classIndex, if provided). If multiLabels, then + * the variable shapes are (T). * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to for variablesToUpdate. * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra * dimension of size 1 that would be squeezed. * @param predictions the predictions shape (N, Cx, P1?) - * @param thresholds thresholds in the range [0, 1], or {@link #NEG_INF} is used - * when topK is set - * @param topK optional, indicates that only the top k predictions should be considered. - * Applied before possibly slicing by classIndex. - * @param classIndex optional, limits the prediction and labels to the specified class. - * This is an integer index into the first dimension of Cx. - * @param sampleWeight optional Tensor that is aligned with labels and predictions - * as explained above. Use weights of 0 to mask values. + * @param thresholds thresholds in the range [0, 1], or {@link #NEG_INF} is used when + * topK is set + * @param topK optional, indicates that only the top k predictions should be considered. 
Applied + * before possibly slicing by classIndex. + * @param classIndex optional, limits the prediction and labels to the specified class. This is an + * integer index into the first dimension of Cx. + * @param sampleWeight optional Tensor that is aligned with labels and predictions as + * explained above. Use weights of 0 to mask values. * @param multiLabel indicates whether multidimensional prediction/labels should be treated as * multilabel responses, or flattened into a single label. When true, the values of * variablesToUpdate must have a second dimension equal to the number of labels and * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES - * without explicit multilabel handling (i.e. when the data is to be flattened). - * Must have shape (Dx), which is the same as (Cx) referenced above, except that if - * classIndex is provided, then the final dimension of Dx is 1. These weights - * will be broadcast across the 0th dimension (the examples dimension) of - * predictions. May be null. Must be null if multiLabel. + * without explicit multilabel handling (i.e. when the data is to be flattened). Must have + * shape (Dx), which is the same as (Cx) referenced above, except that if classIndex + * is provided, then the final dimension of Dx is 1. These weights will be broadcast + * across the 0th dimension (the examples dimension) of predictions. May be null. + * Must be null if multiLabel. * @param the data type for the variables * @throws IllegalArgumentException If predictions and labels have * mismatched shapes, or if sampleWeight is not nulland its shape - * doesn't match predictions, or if multiLabel && labelWeights != null. + * doesn't match predictions, or if multiLabel && labelWeights != null + * . 
* @return an op to update the given confusion matrix variables. */ @SuppressWarnings({"unchecked", "rawtypes"}) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 2a26967b9f2..3b54ad2e08d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -106,11 +106,11 @@ public Op resetStates() { } /** - * Updates the metric variables based on the inputs. At least one input arg required for - * values, an optional additional input for the sampleWeights + * Updates the metric variables based on the inputs. At least one input arg required for {@code + * values}, an optional additional input for the sampleWeights * * @param values the inputs to be passed to update state, this may not be null - * @param sampleWeights sample weights to be applied to values, may be null. + * @param sampleWeights sample weights to be applied to values, will default to 1 if null. * @return the result with a control dependency on update state Operands * @throws IllegalArgumentException if values is null */ @@ -129,13 +129,16 @@ public List updateStateList( if (sampleWeights != null) { tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + // Update dimensions of weights to match with values if possible. LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); tValues = tuple.getTarget(); tSampleWeights = tuple.getSampleWeights(); try { + // Broadcast weights if possible tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); } catch (IllegalArgumentException ex) { + // reduce values to same ndim as weight array // if we get here we have static shapes with either // different ranks or different dimension sizes. 
// first, reduce the values down to the rank of the samples @@ -162,7 +165,9 @@ public List updateStateList( getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; + // Exit early if the reduction doesn't have a denominator. if (reduction != MetricReduction.SUM) { + // Update `count` for reductions that require a denominator. switch (reduction) { case SUM_OVER_BATCH_SIZE: numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); @@ -183,6 +188,7 @@ public List updateStateList( throw new UnsupportedOperationException( String.format("reduction [%s] not implemented", reduction)); } + Operand totalCount = getTF().assignAdd(this.count, numValues); updateOperations.add(totalCount); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 84898d8a4d3..08b298294ac 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -10,7 +10,11 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -111,6 +115,11 @@ private void init() { } } + /** + * Gets a control dependency Op to initialize all the variables + * + * @return a control dependency Op to initialize all the variables + */ public Op initializeVariables() { List varInitializers = new ArrayList<>(); @@ -130,7 +139,15 @@ public Op initializeVariables() { return getTF().withControlDependencies(varInitializers).noOp(); } - /** {@inheritDoc} */ 
+ /** + * Accumulates confusion matrix statistics. + * + * @param labels The ground truth values. + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ @Override @SuppressWarnings("unchecked") public List updateStateList( From 001f0517f78ac4abcee06290384e9102846a7419 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 19:45:26 -0400 Subject: [PATCH 43/97] Fixed Javadoc errors. Change all xxxxxx to {@code xxxxxx} --- .../org/tensorflow/framework/metrics/AUC.java | 102 +++++++------- .../framework/metrics/Accuracy.java | 4 +- .../framework/metrics/BinaryAccuracy.java | 4 +- .../metrics/CategoricalAccuracy.java | 14 +- .../metrics/CategoricalCrossentropy.java | 26 ++-- .../framework/metrics/FalseNegatives.java | 24 ++-- .../framework/metrics/FalsePositives.java | 24 ++-- .../tensorflow/framework/metrics/MeanIoU.java | 6 +- .../framework/metrics/MeanRelativeError.java | 6 +- .../framework/metrics/Precision.java | 36 ++--- .../framework/metrics/PrecisionAtRecall.java | 2 +- .../tensorflow/framework/metrics/Recall.java | 18 +-- .../framework/metrics/RecallAtPrecision.java | 2 +- .../metrics/RootMeanSquaredError.java | 2 +- .../metrics/SensitivityAtSpecificity.java | 14 +- .../metrics/SparseCategoricalAccuracy.java | 4 +- .../metrics/SpecificityAtSensitivity.java | 16 +-- .../org/tensorflow/framework/metrics/Sum.java | 6 +- .../metrics/TopKCategoricalAccuracy.java | 2 +- .../framework/metrics/TrueNegatives.java | 24 ++-- .../framework/metrics/TruePositives.java | 24 ++-- .../impl/ConfusionMatrixConditionCount.java | 8 +- .../framework/metrics/impl/LossMetric.java | 2 +- .../metrics/impl/MeanMetricWrapper.java | 4 +- .../framework/metrics/impl/MetricsHelper.java | 129 +++++++++--------- .../framework/metrics/impl/Reduce.java | 4 +- 
.../framework/metrics/impl/SetsOps.java | 42 +++--- .../metrics/impl/WeightsBroadcastOps.java | 10 +- 28 files changed, 279 insertions(+), 280 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 3dbc6f22cec..bc5047d5855 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -40,25 +40,25 @@ /** * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. * - *

This metric creates four local variables, truePositives, trueNegatives - * , falsePositives and falseNegatives that are used to compute the + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives + * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the * AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of * recall and precision values. The area under the ROC-curve is therefore computed using the height * of the recall values by the false positive rate, while the area under the PR-curve is the * computed using the height of the precision values by the recall. * - *

This value is ultimately returned as auc, an idempotent operation that computes + *

This value is ultimately returned as {@code auc}, an idempotent operation that computes * the area under a discretized curve of precision versus recall values (computed using the - * aforementioned variables). The numThresholds variable controls the degree of + * aforementioned variables). The {@code numThresholds} variable controls the degree of * discretization with larger numbers of thresholds more closely approximating the true AUC. The - * quality of the approximation may vary dramatically depending on numThresholds. The - * thresholds parameter can be used to manually specify thresholds which split the + * quality of the approximation may vary dramatically depending on {@code numThresholds}. The + * {@code thresholds} parameter can be used to manually specify thresholds which split the * predictions more evenly. * - *

For best results, predictions should be distributed approximately uniformly in + *

For best results, {@code predictions} should be distributed approximately uniformly in * the range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor - * if this is not the case. Setting summationMethod to minoring or - * majoring can help quantify the error in the approximation by providing lower or upper + * if this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code + * majoring} can help quantify the error in the approximation by providing lower or upper * bound estimate of the AUC. * *

Usage:
@@ -155,8 +155,8 @@ public class AUC extends Metric { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link #DEFAULT_NUM_THRESHOLDS} for the numThresholds, {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for - * thresholds, false for multiLabel, and null for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed @@ -180,11 +180,11 @@ public AUC(Ops tf, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NUM_THRESHOLDS} for the * numThresholds, {@link AUCCurve#ROC} for the curve type, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, null for thresholds, - * false for multiLabel, and null for labelWeights. + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. 
@@ -206,8 +206,8 @@ public AUC(Ops tf, String name, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, null for thresholds, false for multiLabel, and - * null for labelWeights. + * summation method, {@code null} for thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -233,8 +233,8 @@ public AUC(Ops tf, int numThresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, null for numThresholds, false for multiLabel, and - * null for labelWeights. + * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -259,11 +259,11 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric. using {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for - * thresholds, false for multiLabel, and null for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. 
* * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param seed the seed for random number generation. An initializer created with a given seed @@ -285,13 +285,13 @@ public AUC(Ops tf, String name, int numThresholds, long seed, Class type) { } /** - * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link * AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the summation - * method, {@link #DEFAULT_NUM_THRESHOLDS} num thresholds, false for multiLabel, and - * null for labelWeights. + * method, {@link #DEFAULT_NUM_THRESHOLDS} num thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. * @param seed the seed for random number generation. An initializer created with a given seed @@ -314,11 +314,11 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link AUCSummationMethod#INTERPOLATION} for - * the summation method, null for thresholds, false for multiLabel, and - * null for labelWeights. + * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. 
* * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -342,12 +342,12 @@ public AUC(Ops tf, String name, int numThresholds, AUCCurve curve, long seed, Cl } /** - * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link * AUCSummationMethod#INTERPOLATION} for the summation method, {@link #DEFAULT_NUM_THRESHOLDS} num - * thresholds, false for multiLabel, and null for labelWeights. + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -372,8 +372,8 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for - * thresholds, false for multiLabel, and null for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. 
* * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -399,9 +399,9 @@ public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) } /** - * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, false for multiLabel, - * and null for labelWeights. + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, + * and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -428,7 +428,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) /** * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, - * null for thresholds, false for multiLabel, and null for + * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for * labelWeights. * * @param tf The TensorFlow Ops @@ -453,7 +453,7 @@ public AUC( /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * null for numThresholds, false for multiLabel, and null + * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} * for labelWeights. * * @param tf The TensorFlow Ops @@ -487,11 +487,11 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using null for thresholds, - * false for multiLabel, and null for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. 
* * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -513,11 +513,11 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using null for the numThresholds, - * false for multiLabel, and null for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -560,15 +560,15 @@ public AUC( * @param summationMethod Specifies the Riemann summation method used * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. This method - * automatically brackets the provided thresholds with a (-{@link #EPSILON}) + * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) * below and a (1 + {@link #EPSILON}) above. * @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein * AUC is computed separately for each label and then averaged across labels, or (when false) * if the data should be flattened into a single label before AUC computation. 
In the latter * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an - * individual data point. Should be set to false for multi-class data. - * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When - * multiLabel is true, the weights are applied to the individual label AUCs when they + * individual data point. Should be set to {@code false} for multi-class data. + * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When {@code + * multiLabel} is true, the weights are applied to the individual label AUCs when they * are averaged to produce the multi-label AUC. When it's false, they are used to weight the * individual label predictions in computing the confusion matrix on the flattened data. * @param seed the seed for random number generation. An initializer created with a given seed @@ -715,9 +715,9 @@ private Map> build(Shape shape) { * * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be - * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} != null - * , then Cx must be a single dimension. - * @param predictions the predictions shape (N, Cx, P1?). Will be cast to T. + * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} {@code != null + * }, then Cx must be a single dimension. + * @param predictions the predictions shape (N, Cx, P1?). Will be cast to {@code T}. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to * {@code }. 
* @return a List of Operations to update the metric state diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index 30787a9889b..516d6c91ba6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -29,11 +29,11 @@ * Metric that calculates how often predictions equals labels. * *

This metric creates two local variables, total and count that are used to compute the - * frequency with which predictions matches labels. This frequency is + * frequency with which {@code predictions} matches {@code labels}. This frequency is * ultimately returned as binary accuracy: an idempotent operation that simply divides total by * count. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 to mask * values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index 4f9a267d633..0e41699e165 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -26,11 +26,11 @@ * Metric that calculates how often predictions matches binary labels. * *

This metric creates two local variables, total and count that are used to compute the - * frequency with which predictions matches labels. This frequency is + * frequency with which {@code predictions} matches {@code labels}. This frequency is * ultimately returned as binary accuracy: an idempotent operation that simply divides total by * count. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of 0 to mask * values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index 55c3dc800e1..dece2d1cd50 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,18 +27,18 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

You can provide logits of classes as predictions, since argmax of - * logits and probabilities are same. + *

You can provide {@code logits} of classes as {@code predictions}, since argmax of + * {@code logits} and probabilities are the same. * - *

This metric creates two local variables, total and count that are - * used to compute the frequency with which predictions matches labels. + *

This metric creates two local variables, {@code total} and {@code count} that are + * used to compute the frequency with which {@code predictions} matches {@code labels}. * This frequency is ultimately returned as categorical accuracy: an idempotent operation that * simply divides total by count. * - *

predictions and labels should be passed in as vectors of + *

{@code predictions} and {@code labels} should be passed in as vectors of * probabilities, rather than as labels. If necessary, use {@link * org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand - * labels as a vector. + * {@code labels} as a vector. * *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. * @@ -77,7 +77,7 @@ public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { * Computes the categorical crossentropy loss. * *

{@code predictions} and {@code labels} should be passed in as vectors of probabilities, - * rather than as labels. If necessary, use {@line Ops#oneHot} to expand {@code labels} as a + * rather than as labels. If necessary, use {@link Ops#oneHot} to expand {@code labels} as a * vector. * * @param labels One-hot ground truth values. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index a7e85ce5b02..58aa51f664c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -28,9 +28,9 @@ * labels. * *

This is the crossentropy metric class to be used when there are multiple label classes (2 or - * more). The labels should be given as a one_hot representation. eg., When labels values are - * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] - * . + * more). The labels should be given as a one_hot representation. eg., When labels values are {@code + * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + * }. * * @param The data type for the metric result */ @@ -52,9 +52,9 @@ public class CategoricalCrossentropy extends MeanMetricWrappe * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 - * means that we will use a value of 0.1 for label 0 and 0.9 - * for label 1 + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} + * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 + * } for label {@code 1} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result @@ -73,13 +73,13 @@ public CategoricalCrossentropy( * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 - * means that we will use a value of 0.1 for label 0 and 0.9 - * for label 1 - * @param axis Int specifying the channels axis. 
axis={@link Losses#CHANNELS_LAST} - * corresponds to data format channels_last, and - * axis={@link Losses#CHANNELS_FIRST} corresponds to data format - * channels_first. + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} + * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 + * } for label {@code 1} + * @param axis Int specifying the channels axis. {@code axis={@link Losses#CHANNELS_LAST}} + * corresponds to data format {@code channels_last}, and {@code + * axis={@link Losses#CHANNELS_FIRST}} corresponds to data format {@code + * channels_first}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java index 39d33dda665..3db7fffc2e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false negatives. * - *

If sampleWeights is given, calculates the sum of the weights of false negatives. - * This metric creates one local variable, accumulator that is used to keep track of + *

If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of * the number of false negatives. * - *

If sampleWeights is null, weights default to 1. Use - * sampleWeights of 0 to mask values. + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. * * @param The data type for the metric result */ @@ -50,9 +50,9 @@ public FalseNegatives(Ops tf, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -66,9 +66,9 @@ public FalseNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
@@ -96,9 +96,9 @@ public FalseNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -113,9 +113,9 @@ public FalseNegatives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java index 3cf9fc0a5e9..551529b6179 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false positives. * - *

If sampleWeights is given, calculates the sum of the weights of false positives. - * This metric creates one local variable, accumulator that is used to keep track of + *

If {@code sampleWeights} is given, calculates the sum of the weights of false positives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of * the number of false positives. * - *

If sampleWeights is null, weights default to 1. Use - * sampleWeights of 0 to mask values. + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. * * @param The data type for the metric result */ @@ -50,9 +50,9 @@ public FalsePositives(Ops tf, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -66,9 +66,9 @@ public FalsePositives(Ops tf, float threshold, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
@@ -96,9 +96,9 @@ public FalsePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -113,9 +113,9 @@ public FalsePositives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 70c2e6db8f6..03c31b2bab8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -35,11 +35,11 @@ * *

Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, * which first computes the IOU for each semantic class and then computes the average over classes. - * IOU is defined as follows: IOU = true_positive - * / (true_positive + false_positive + false_negative). The predictions are accumulated in a + * IOU is defined as follows: {@code IOU = true_positive + * / (true_positive + false_positive + false_negative)}. The predictions are accumulated in a * confusion matrix, weighted by sample_weight and the metric is then calculated from it. * - *

If sampleWeight is null, weights default to 1. Use sample_weight of + *

If {@code sampleWeight} is {@code null}, weights default to 1. Use sample_weight of * 0 to mask values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index ac25183c0e5..acf28f5b2cc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -28,12 +28,12 @@ /** * Computes the mean relative error by normalizing with the given values. * - *

This metric creates two local variables, total and count that are - * used to compute the mean relative error. This is weighted by sampleWeight, and it is + *

This metric creates two local variables, {@code total} and {@code count} that are + * used to compute the mean relative error. This is weighted by {@code sampleWeight}, and it is * ultimately returned as mean relative error: an idempotent operation that simply divides total by * count. * - *

If {@code sampleWeight} is null, weights default to 1. Use {@code sampleWeight} + *

If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} * of 0 to mask values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index 5784cf46385..c56c53addf0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -36,22 +36,22 @@ /** * Computes the precision of the predictions with respect to the labels. * - *

The metric creates two local variables, truePositives and falsePositives - * that are used to compute the precision. This value is ultimately returned as precision, - * an idempotent operation that simply divides truePositives by the sum of - * truePositives and falsePositives. + *

The metric creates two local variables, {@code truePositives} and {@code falsePositives + * } that are used to compute the precision. This value is ultimately returned as precision, + * an idempotent operation that simply divides {@code truePositives} by the sum of {@code + * truePositives} and {@code falsePositives}. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of * 0 to mask values. * - *

If topK is set, the metric calculates precision as how often on average a class + *

If {@code topK} is set, the metric calculates precision as how often on average a class * among the top-k classes with the highest predicted values of a batch entry is correct and can be * found in the label for that entry. * - *

If classId is specified, the metric calculates precision by considering only the - * entries in the batch for which classId is above the thresholds and/or - * in the top-k highest predictions, and computing the fraction of them for which classId - * is indeed a correct label. + *

If {@code classId} is specified, the metric calculates precision by considering only the + * entries in the batch for which {@code classId} is above the {@code thresholds} and/or + * in the top-k highest predictions, and computing the fraction of them for which {@code classId + * } is indeed a correct label. * * @param The data type for the metric result */ @@ -103,7 +103,7 @@ public Precision(Ops tf, String name, long seed, Class type) { * values. * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -120,7 +120,7 @@ public Precision(Ops tf, float threshold, long seed, Class type) { * values. * * @param tf the TensorFlow Ops - * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -138,7 +138,7 @@ public Precision(Ops tf, float[] thresholds, long seed, Class type) { * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. 
@@ -156,7 +156,7 @@ public Precision(Ops tf, String name, float threshold, long seed, Class type) * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -172,7 +172,7 @@ public Precision(Ops tf, String name, float[] thresholds, long seed, Class ty * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -193,7 +193,7 @@ public Precision( * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -216,7 +216,7 @@ public Precision( * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range [0, 1]. 
A threshold is + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -245,7 +245,7 @@ public Precision( * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 4205f761e4b..483b2523d74 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -29,7 +29,7 @@ * falseNegatives that are used to compute the precision at the given recall. The threshold for the * given recall value is computed and used to evaluate the corresponding precision. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of + *

If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of * 0 to mask values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index ca5968d4f9d..3886ec050b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -36,20 +36,20 @@ /** * Computes the recall of the predictions with respect to the labels. * - *

This metric creates two local variables, truePositives and falseNegatives - * , that are used to compute the recall. This value is ultimately returned as recall, an - * idempotent operation that simply divides truePositives by the sum of - * truePositives and falseNegatives. + *

This metric creates two local variables, {@code truePositives} and {@code falseNegatives + * }, that are used to compute the recall. This value is ultimately returned as recall, an + * idempotent operation that simply divides {@code truePositives} by the sum of {@code + * truePositives} and {@code falseNegatives}. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of * 0 to mask values. * - *

If topK is set, the metric calculates recall as how often on average a class + *

If {@code topK} is set, the metric calculates recall as how often on average a class * among the labels of a batch entry is in the top-k predictions. * - *

If classId is specified, the metric calculates recall by considering only the - * entries in the batch for which classId is in the label, and computing the fraction - * of them for which classId is above the threshold and/or in the top-k predictions. + *

If {@code classId} is specified, the metric calculates recall by considering only the + * entries in the batch for which {@code classId} is in the label, and computing the fraction + * of them for which {@code classId} is above the threshold and/or in the top-k predictions. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index fb6890d1e01..72eaedb9c4d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -34,7 +34,7 @@ * falseNegatives that are used to compute the recall at the given precision. The threshold for the * given precision value is computed and used to evaluate the corresponding recall. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of + *

If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of * 0 to mask values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 721b95487c7..3886428425b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -27,7 +27,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes root mean squared error metric between labels and predictions + * Computes root mean squared error metric between {@code labels} and {@code predictions} * . * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 2c7420a5518..7cf5f38d9a4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -25,18 +25,18 @@ /** * Computes best sensitivity where sensitivity is >= specified value. * - *

Sensitivity measures the proportion of actual positives that are correctly - * identified as such (tp / (tp + fn)). + *

{@code Sensitivity} measures the proportion of actual positives that are correctly + * identified as such {@code (tp / (tp + fn))}. * - *

Specificity measures the proportion of actual negatives that are correctly - * identified as such (tn / (tn + fp)). + *

{@code Specificity} measures the proportion of actual negatives that are correctly + * identified as such {@code (tn / (tn + fp))}. * - *

This metric creates four local variables, truePositives, trueNegatives - * , falsePositives and falseNegatives that are used to compute the + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives + * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the * sensitivity at the given specificity. The threshold for the given specificity value is computed * and used to evaluate the corresponding sensitivity. * - *

If sampleWeights is null, weights default to 1. Use sample_weight of + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of * 0 to mask values. * * @see Additional information diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 6dfdab48578..5294f798044 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -31,11 +31,11 @@ /** * Calculates how often predictions matches integer labels. * - *

You can provide logits of classes as predictions, since argmax of logits and + *

You can provide logits of classes as {@code predictions}, since argmax of logits and * probabilities are same. * *

This metric creates two local variables, `total` and `count` that are used to compute the - * frequency with which predictions matches labels. This frequency is + * frequency with which {@code predictions} matches {@code labels}. This frequency is * ultimately returned as `sparse categorical accuracy`: an idempotent operation that simply divides * `total` by `count`. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index d0b797690bd..981171f2221 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -23,19 +23,19 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes best specificity where sensitivity is >= specified value. Sensitivity - * measures the proportion of actual positives that are correctly identified as such - * (tp / (tp + fn)). + * Computes best specificity where sensitivity is >= specified value. {@code Sensitivity} + * measures the proportion of actual positives that are correctly identified as such {@code + * (tp / (tp + fn))}. * - *

Specificity measures the proportion of actual negatives that are correctly - * identified as such (tn / (tn + fp)). + *

{@code Specificity} measures the proportion of actual negatives that are correctly + * identified as such {@code (tn / (tn + fp))}. * - *

This metric creates four local variables, truePositives, trueNegatives - * , falsePositives and falseNegatives that are used to compute the + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives + * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the * specificity at the given sensitivity. The threshold for the given sensitivity value is computed * and used to evaluate the corresponding specificity. * - *

If sampleWeights is null, weights default to 1. Use sample_weight of + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of * 0 to mask values. * * @see Additional information diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java index a3241221b66..637ca6cdd05 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -21,10 +21,10 @@ /** * Computes the (weighted) sum of the given values. * - *

For example, if values is [1, 3, 5, 7] then the sum is 16. If the - * weights were specified as [1, 1, 0, 0], then the sum would be 4. + *

For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the + * weights were specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} * - *

This metric creates one variable, total, that is used to compute the sum of + *

This metric creates one variable, {@code total}, that is used to compute the sum of * values. This is ultimately returned as sum. * *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index 9c8d6403a6b..0146552433f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -34,7 +34,7 @@ public class TopKCategoricalAccuracy extends MeanMetricWrappe private final int k; /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of * top elements to look at for computing accuracy. * * @param tf the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java index 91b6751588a..5c65f8c469f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true negatives. * - *

If sampleWeights is given, calculates the sum of the weights of true negatives. - * This metric creates one local variable, accumulator that is used to keep track of + *

If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of * the number of true negatives. * - *

If sampleWeights is null, weights default to 1. Use - * sampleWeights of 0 to mask values. + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. * * @param The data type for the metric result */ @@ -50,9 +50,9 @@ public TrueNegatives(Ops tf, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -66,9 +66,9 @@ public TrueNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
@@ -96,9 +96,9 @@ public TrueNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -113,9 +113,9 @@ public TrueNegatives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java index b67d381a62d..f0dd8c42de5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true positives. * - *

If sampleWeights is given, calculates the sum of the weights of true positives. - * This metric creates one local variable, accumulator that is used to keep track of + *

If {@code sampleWeights} is given, calculates the sum of the weights of true positives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of * the number of true positives. * - *

If sampleWeights is null, weights default to 1. Use - * sampleWeights of 0 to mask values. + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. * * @param The data type for the metric result */ @@ -50,9 +50,9 @@ public TruePositives(Ops tf, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -66,9 +66,9 @@ public TruePositives(Ops tf, float threshold, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
@@ -96,9 +96,9 @@ public TruePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -113,9 +113,9 @@ public TruePositives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index 63ea35df7f2..88597cf85ec 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -67,9 +67,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param threshold a threshold value in [0, 1]. A threshold is compared with + * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is - * true, below is false). One metric value is generated for each + * {@code true}, below is {@code false}). One metric value is generated for each * threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -91,9 +91,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param thresholds threshold values in [0, 1]. A threshold is compared with + * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is - * true, below is false). One metric value is generated for each + * {@code true}, below is {@code false}). 
One metric value is generated for each * threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 1fb3d3bb580..f89047e457d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -25,7 +25,7 @@ public interface LossMetric { /** - * Calculates the weighted loss between labels and predictions + * Calculates the weighted loss between {@code labels} and {@code predictions} * * @param labels the truth values or labels * @param predictions the predictions diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 9a532a0294f..37bdd5849ae 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -29,8 +29,8 @@ * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. * - *

The loss function calculates the loss between the labels and predictions - * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the + *

The loss function calculates the loss between the {@code labels} and {@code predictions + * } then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index f36aaa34d8f..54b2646a62b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -58,8 +58,8 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the sampleWeights can be broadcast to the same shape as values - * + * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values + * } * *

In losses and metrics, limited weight broadcasting is supported. Weights must be either * scalar, or the same rank as the target values, with each dimension either 1, or the same as the @@ -68,11 +68,11 @@ public class MetricsHelper { * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return Operation with control dependencies to ensure sampleWeight - * can be broadcast to values + * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} + * can be broadcast to {@code values} * @param the type of Operand - * @throws NotBroadcastableException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to values + * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an + * incorrect shape that prohibit broadcasting to {@code values} */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -200,13 +200,13 @@ private static Operand canBroadcastDims( } /** - * Broadcast weights to the same shape as values. + * Broadcast {@code weights} to the same shape as {@code values}. * * @param tf the TensorFlow ops - * @param weights Operand whose shape is broadcastable to values. + * @param weights Operand whose shape is broadcastable to {@code values}. * @param values Operand of any shape * @param the type of Operands - * @return weights broadcast to values shape + * @return {@code weights} broadcast to {@code values} shape */ public static Operand broadcastWeights( Ops tf, Operand weights, Operand values) { @@ -291,13 +291,13 @@ public static List assertShapes( /** * Returns an op to update the given confusion matrix variables. * - *

For every pair of values in labels and predictions: + *

For every pair of values in {@code labels} and {@code predictions}: * *

-   * TRUE_POSITIVES:  labels == true and predictions > thresholds
-   * FALSE_POSITIVES: labels == true and predictions <= thresholds
-   * TRUE_NEGATIVES:  labels == false and predictions <= thresholds
-   * FALSE_NEGATIVE:  labels == false and predictions > thresholds
+   * TRUE_POSITIVES:  {@code labels} == true and {@code predictions} > thresholds
+   * FALSE_POSITIVES: {@code labels} == true and {@code predictions} <= thresholds
+   * TRUE_NEGATIVES:  {@code labels} == false and {@code predictions} <= thresholds
+   * FALSE_NEGATIVE:  {@code labels} == false and {@code predictions} > thresholds
    * 
* *

The results will be weighted and added together. When multiple thresholds are provided, we @@ -306,46 +306,45 @@ public static List assertShapes( *

For estimation of these metrics over a stream of data, the function creates an `update_op` * operation that updates the given variables. * - *

labels, predictions, and sampleWeight tensors are - * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. - * sampleWeight is then broadcast to the shape of predictions. + *

{@code labels}, {@code predictions}, and {@code sampleWeight} tensors are + * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code + * sampleWeight} is then broadcast to the shape of {@code predictions}. * * @param tf the TensorFlow Ops * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If multiLabel, then the variable + * corresponding variables to update as values. If {@code multiLabel}, then the variable * shapes are (T, D), where T is the number of thresholds and D is the number of classes - * (after slicing by classIndex, if provided). If multiLabels, then + * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then * the variable shapes are (T). * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding initializer Operands to for variablesToUpdate. + * corresponding initializer Operands to for {@code variablesToUpdate}. * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra * dimension of size 1 that would be squeezed. * @param predictions the predictions shape (N, Cx, P1?) - * @param thresholds thresholds in the range [0, 1], or {@link #NEG_INF} is used when + * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when * topK is set * @param topK optional, indicates that only the top k predictions should be considered. Applied - * before possibly slicing by classIndex. + * before possibly slicing by {@code classIndex}. * @param classIndex optional, limits the prediction and labels to the specified class. This is an * integer index into the first dimension of Cx. 
- * @param sampleWeight optional Tensor that is aligned with labels and predictions as + * @param sampleWeight optional {@code Tensor} that is aligned with labels and predictions as * explained above. Use weights of 0 to mask values. * @param multiLabel indicates whether multidimensional prediction/labels should be treated as - * multilabel responses, or flattened into a single label. When true, the values of - * variablesToUpdate must have a second dimension equal to the number of labels and + * multilabel responses, or flattened into a single label. When true, the values of {@code + * variablesToUpdate} must have a second dimension equal to the number of labels and * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. when the data is to be flattened). Must have - * shape (Dx), which is the same as (Cx) referenced above, except that if classIndex - * is provided, then the final dimension of Dx is 1. These weights will be broadcast - * across the 0th dimension (the examples dimension) of predictions. May be null. - * Must be null if multiLabel. + * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex + * } is provided, then the final dimension of Dx is 1. These weights will be broadcast + * across the 0th dimension (the examples dimension) of {@code predictions}. May be null. + * Must be null if {@code multiLabel}. * @param the data type for the variables - * @throws IllegalArgumentException If predictions and labels have - * mismatched shapes, or if sampleWeight is not nulland its shape - * doesn't match predictions, or if multiLabel && labelWeights != null - * . 
+ * @throws IllegalArgumentException If {@code predictions} and {@code labels} have + * mismatched shapes, or if {@code sampleWeight} is not null and its shape + * doesn't match {@code predictions}, or if {@code multiLabel && labelWeights != null}.. * @return an op to update the given confusion matrix variables. */ @SuppressWarnings({"unchecked", "rawtypes"}) @@ -689,8 +688,8 @@ private static Operand filterTopK(Ops tf, Operand x, i // alias for mean /** - * Calculate the mean of the operand, along all axes and keepDims is false - * + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false + * } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -702,8 +701,8 @@ public static Operand mean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is + * {@code false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -721,12 +720,12 @@ public static Operand mean( * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. * @param the type of the operand - * @return the mean of elements of x. + * @return the mean of elements of {@code x}. 
*/ public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); @@ -738,12 +737,12 @@ public static Operand mean(Ops tf, Operand x, boolean * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. * @param the data type of the Operand - * @return the mean of elements of x. + * @return the mean of elements of {@code x}. */ public static Operand mean( Ops tf, Operand x, Operand axes, boolean keepDims) { @@ -778,16 +777,16 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * *

For example: * - *

+   * 
{@code
    *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
    *          [[0 0 0 0 0]
    *           [0 0 1 0 0]
    *           [0 0 1 0 0]
    *           [0 0 0 0 0]
    *           [0 0 0 0 1]]
-   * 
+ * }
* - * Note that the possible labels are assumed to be {@copde [0, 1, 2, 3,4]}, resulting in a 5x5 + * Note that the possible labels are assumed to be {@code [0, 1, 2, 3,4]}, resulting in a 5x5 * confusion matrix. * * @param tf the TensorFlow Ops @@ -798,12 +797,12 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * @param weights optional weights to be applied to the confusion matrix * @param type Data type of the confusion matrix. * @param the type of Operands - * @return A Operand of type type with shape [n, n] - * representing the confusion matrix, where n is the number of possible labels in + * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} + * representing the confusion matrix, where {@code n} is the number of possible labels in * the classification task. - * @throws IllegalArgumentException If both predictions and labels do - * not have compatible shapes, or if weights is notnull and its - * shape is not compatible with predictions. + * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do + * not have compatible shapes, or if {@code weights} is not{@code null} and its + * shape is not compatible with {@code predictions}. */ // TODO should this be moved to FramnworkOps under math. 
public static Operand confusionMatrix( @@ -879,8 +878,8 @@ public static Operand confusionMatrix( } /** - * Calculate the mean of the operand, along all axes and keepDims is false - * + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false + * } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -891,8 +890,8 @@ public static Operand booleanMean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is + * {@code false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -909,11 +908,11 @@ public static Operand booleanMean( * * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. - * @return the mean of elements of x containing floating point numbers + * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); @@ -925,11 +924,11 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. 
If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. - * @return the mean of elements of x containing floating point numbers + * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean( Ops tf, Operand x, Operand axes, boolean keepDims) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 3b54ad2e08d..b96d2dfa1d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -106,8 +106,8 @@ public Op resetStates() { } /** - * Updates the metric variables based on the inputs. At least one input arg required for {@}code - * values}, an optional additional input for the sampleWeights + * Updates the metric variables based on the inputs. At least one input arg required for {@code + * values}, an optional additional input for the {@code sampleWeights} * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, will default to 1 if null. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 467dea19b57..68157632557 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -26,16 +26,16 @@ public class SetsOps { /** - * Computes set difference of elements in last dimension of a and b with - * aMinusB set to true. + * Computes set difference of elements in last dimension of {@code a} and {@code b} with + * {@code aMinusB} set to true. * - *

All but the last dimension of a and b must match + *

All but the last dimension of {@code a} and {@code b} must match * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -44,16 +44,16 @@ public static Operand difference(Ops tf, Operand a, Op } /** - * Computes set difference of elements in last dimension of a and b. + * Computes set difference of elements in last dimension of {@code a} and {@code b}. * - *

All but the last dimension of a and b must match + *

All but the last dimension of {@code a} and {@code b} must match * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param aMinusB whether to subtract b from a, vs vice versa. * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -63,13 +63,13 @@ public static Operand difference( } /** - * Computes set union of elements in last dimension of a and b. + * Computes set union of elements in last dimension of {@code a} and {@code b}. * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -78,13 +78,13 @@ public static Operand union(Ops tf, Operand a, Operand } /** - * Computes set intersection of elements in last dimension of a and b. + * Computes set intersection of elements in last dimension of {@code a} and {@code b}. 
* * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -93,14 +93,14 @@ public static Operand intersection(Ops tf, Operand a, } /** - * Compute set operation of elements in last dimension of a and b. + * Compute set operation of elements in last dimension of {@code a} and {@code b}. * * @param tf the TensorFlow Ops * @param a The first set operation operand * @param b The other et operation operand * @param setOperation The set operation to perform, {@link Operation}. * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the same. Elements along the last dimension contain the results of the set * operation. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 36792b8ea7a..fc7f1abbd89 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -151,17 +151,17 @@ private static Operand hasValidDims( /** * Broadcast `weights` to the same shape as `values`. * - *

This returns a version of weights following the same broadcast rules as + *

This returns a version of {@code weights} following the same broadcast rules as {@code * mul(weights, - * values), but limited to the weights shapes allowed by assertBroadcastable - * When computing a weighted average, use this function to broadcast weights before - * summing them; e.g., reduceSum(w * v) / reduceSum(_broadcast_weights(w, v)). + * values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} + * When computing a weighted average, use this function to broadcast {@code weights} before + * summing them; e.g., {@code reduceSum(w * v) / reduceSum(_broadcast_weights(w, v))}. * * @param tf the TensorFlow ops * @param weights `Tensor` whose shape is able to be broadcast to `values` * @param values Tensor` of any shape * @param the type of Operand - * @return weights broadcast to values shape + * @return {@code weights} broadcast to {@code values} shape */ public static Operand broadcastWeights( Ops tf, Operand weights, Operand values) { From 088e0d8b1ab53da982dd863498c7958a5eca42e5 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 11:16:37 -0500 Subject: [PATCH 44/97] Simplify generic parameters across losses and metrics. 
--- .../framework/losses/CategoricalHinge.java | 4 +-- .../tensorflow/framework/losses/Hinge.java | 19 ++++++------ .../tensorflow/framework/losses/Huber.java | 6 ++-- .../framework/losses/KLDivergence.java | 2 +- .../tensorflow/framework/losses/LogCosh.java | 5 ++-- .../org/tensorflow/framework/losses/Loss.java | 5 ++-- .../tensorflow/framework/losses/Losses.java | 29 ++++++------------- .../framework/losses/MeanAbsoluteError.java | 2 +- .../losses/MeanAbsolutePercentageError.java | 2 +- .../framework/losses/MeanSquaredError.java | 2 +- .../losses/MeanSquaredLogarithmicError.java | 2 +- .../framework/metrics/BinaryCrossentropy.java | 14 ++++----- .../metrics/CategoricalCrossentropy.java | 9 ++---- .../framework/metrics/CategoricalHinge.java | 11 ++----- .../framework/metrics/CosineSimilarity.java | 14 +++------ .../tensorflow/framework/metrics/Hinge.java | 12 +++----- .../framework/metrics/KLDivergence.java | 12 +++----- .../framework/metrics/LogCoshError.java | 12 +++----- .../framework/metrics/MeanAbsoluteError.java | 11 ++----- .../metrics/MeanAbsolutePercentageError.java | 13 +++------ .../framework/metrics/MeanSquaredError.java | 11 ++----- .../metrics/MeanSquaredLogarithmicError.java | 13 +++------ .../tensorflow/framework/metrics/Poisson.java | 13 ++++----- .../SparseCategoricalCrossentropy.java | 18 ++++-------- .../framework/metrics/SquaredHinge.java | 12 +++----- .../framework/metrics/impl/LossMetric.java | 2 +- .../framework/metrics/HingeTest.java | 6 ++-- .../framework/metrics/PoissonTest.java | 3 +- 28 files changed, 94 insertions(+), 170 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 73837ed1756..4e9133d8835 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -25,7 +25,7 @@ *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions) * and pos=sum(labels*predictions) * - *

labels values are expected to be 0 or 1. + *

labels values are expected to be 0 or 1.

* *

Standalone usage: * @@ -100,7 +100,7 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index d4c350ef06c..37e7e367b9b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -18,16 +18,15 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; - import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the hinge loss between labels and predictions. * - *

loss = maximum(1 - labels * predictions, 0). + *

loss = maximum(1 - labels * predictions, 0)

. * - *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, - * they will be converted to -1 or 1. + *

labels values are expected to be -1 or 1. + * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.

* *

Standalone usage: * @@ -107,7 +106,7 @@ public Hinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor @@ -117,19 +116,21 @@ public Hinge(Ops tf, String name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. + * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - Operand tLabels = cast(tf, labels, predictions.type()); - tLabels = - LossesHelper.valueCheck( + @SuppressWarnings("unchecked") + Operand tLabels = predictions.type() == labels.type() ? 
+ (Operand)labels : cast(tf, labels, predictions.type()); + tLabels = LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index b1aee1b0656..e8de632eb09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -89,7 +89,6 @@ public Huber(Ops tf) { * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops - * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Huber(Ops tf, String name) { this(tf, name, DELTA_DEFAULT, Reduction.AUTO); @@ -110,7 +109,6 @@ public Huber(Ops tf, Reduction reduction) { * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta * * @param tf the TensorFlow Ops - * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. */ public Huber(Ops tf, String name, Reduction reduction) { @@ -121,7 +119,7 @@ public Huber(Ops tf, String name, Reduction reduction) { * Creates a Huber Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. + * @param name the name of the loss * @param delta the point where the Huber loss function changes from quadratic to linear. * @param reduction Type of Reduction to apply to the loss. 
*/ @@ -133,7 +131,7 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index 2aa1f72092b..b3c0206b409 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -100,7 +100,7 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index a11d582e527..812260d9881 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -77,7 +77,6 @@ public LogCosh(Ops tf) { * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops - * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. 
*/ public LogCosh(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -97,7 +96,7 @@ public LogCosh(Ops tf, Reduction reduction) { * Creates a LogCosh Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. + * @param name the name of the loss * @param reduction Type of Reduction to apply to the loss. */ public LogCosh(Ops tf, String name, Reduction reduction) { @@ -107,7 +106,7 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index cdd35d28aba..0f9b183f38c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -25,7 +25,7 @@ public abstract class Loss { protected final Reduction reduction; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops @@ -64,8 +64,7 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param The data type of the predictions and loss. 
* @return the loss */ - public Operand call( - Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return call(labels, predictions, null); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 9aa94cf7fcf..a5ced3d1df8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -102,10 +102,8 @@ public static Operand meanAbsolutePercentageError( tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum( - tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); - return tf.math.mul( - cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); + return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -151,11 +149,7 @@ public static Operand meanSquaredLogarithmicError( * @return the binary crossentropy loss. */ public static Operand binaryCrossentropy( - Ops tf, - Operand labels, - Operand predictions, - boolean fromLogits, - float labelSmoothing) { + Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -220,10 +214,9 @@ private static Operand binaryCrossentropyHelper( * @param labels true targets * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are - * smoothed, meaning the confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 - * means that we will use a value of 0.1 for label 0 and - * 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the + * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a + * value of 0.1 for label 0 and 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels * @return the categorical crossentropy loss. @@ -510,11 +503,7 @@ public static Operand poisson( * @return the sparse categorical crossentropy loss */ public static Operand sparseCategoricalCrossentropy( - Ops tf, - Operand labels, - Operand predictions, - boolean fromLogits, - int axis) { + Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -655,14 +644,14 @@ private static Operand smoothCategoricalLabels( * @param tf The TensorFlow Ops * @param x the input * @param axis Dimension along which to normalize. 
- * @param the data type for the input and the result * @return the normalized values based on L2 norm */ public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = - tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); + tf.math.rsqrt( + tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); return tf.math.mul(x, invNorm); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 03a3cf70110..594de1e1448 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -96,7 +96,7 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 6c5242df4f2..275a2e136a0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -96,7 +96,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, 
Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index f975db55c44..31df3e70e0b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -96,7 +96,7 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index 11b8e157e90..bef990d22bc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -96,7 +96,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return 
LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 48ee244eafb..abd2dcbbf40 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -31,8 +31,8 @@ * * @param The data type for the metric result */ -public class BinaryCrossentropy extends MeanMetricWrapper - implements LossMetric { +public class BinaryCrossentropy + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -42,8 +42,7 @@ public class BinaryCrossentropy extends MeanMetricWrapper * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a - * probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. 
Larger values of label_smoothing @@ -62,10 +61,7 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); + public Operand call(Operand labels, Operand predictions) { + return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index b22e5415f79..be43f34b92e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -21,8 +21,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the categorical cross-entropy loss between true labels and predicted * labels. 
@@ -101,11 +99,8 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalCrossentropy( - getTF(), tLabels, tPredictions, fromLogits, labelSmoothing, axis); + getTF(), labels, predictions, fromLogits, labelSmoothing, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 4266cc487c0..c70f2d8643b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -21,14 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the categorical hinge loss metric between labels and predictions. 
* * @param The data type for the metric result */ -public class CategoricalHinge extends MeanMetricWrapper +public class CategoricalHinge< T extends TNumber> extends MeanMetricWrapper implements LossMetric { /** @@ -47,10 +45,7 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.categoricalHinge(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.categoricalHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 840f255c5ab..5abbd095420 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -21,14 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the cosine similarity metric between labels and predictions. * * @param The data type for the metric result. 
*/ -public class CosineSimilarity extends MeanMetricWrapper +public class CosineSimilarity< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; @@ -78,12 +76,8 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, - // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.cosineSimilarity(getTF(), tLabels, tPredictions, axis); + public Operand call(Operand labels, Operand predictions) { + // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity + return Losses.cosineSimilarity(getTF(), labels, predictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 46ccd2859ff..e0aced6fa3e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -21,14 +21,13 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the hinge loss metric between labels and predictions. * * @param The data type for the metric result. 
*/ -public class Hinge extends MeanMetricWrapper implements LossMetric { +public class Hinge extends MeanMetricWrapper + implements LossMetric { /** * Creates a Hinge metric @@ -46,10 +45,7 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.hinge(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.hinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 9ffcd6189f1..fa09f2784b5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -21,15 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. * * @param The data type for the metric result. 
*/ -public class KLDivergence extends MeanMetricWrapper implements LossMetric { +public class KLDivergence extends MeanMetricWrapper + implements LossMetric { /** * Creates a KLDivergence metric @@ -47,10 +46,7 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 59e24f57110..c43551a6948 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -21,15 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. * * @param The data type for the metric result. 
*/ -public class LogCoshError extends MeanMetricWrapper implements LossMetric { +public class LogCoshError extends MeanMetricWrapper< T> + implements LossMetric { /** * Creates a LogCoshError metric @@ -47,10 +46,7 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.logCosh(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.logCosh(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 1cc6d0b6f99..d343ec77ab0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -21,14 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the mean of absolute difference between labels and predictions. * * @param The data type for the metric result. 
*/ -public class MeanAbsoluteError extends MeanMetricWrapper +public class MeanAbsoluteError extends MeanMetricWrapper< T> implements LossMetric { /** @@ -47,10 +45,7 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.meanAbsoluteError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 8c6720b58f6..dd7d151260b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -21,15 +21,13 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the mean of absolute difference between labels and predictions. * * @param The data type for the metric result. 
*/ -public class MeanAbsolutePercentageError extends MeanMetricWrapper - implements LossMetric { +public class MeanAbsolutePercentageError + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -47,10 +45,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 3c4c79d39ba..c2bef576b30 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -21,14 +21,12 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the mean of absolute difference between labels and predictions. * * @param The data type for the metric result. 
*/ -public class MeanSquaredError extends MeanMetricWrapper +public class MeanSquaredError< T extends TNumber> extends MeanMetricWrapper< T> implements LossMetric { /** @@ -47,10 +45,7 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.meanSquaredError(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.meanSquaredError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index d525bb76648..c1cf4ca6c9a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -21,15 +21,13 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the mean of absolute difference between labels and predictions. * * @param The data type for the metric result. 
*/ -public class MeanSquaredLogarithmicError extends MeanMetricWrapper - implements LossMetric { +public class MeanSquaredLogarithmicError + extends MeanMetricWrapper implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -47,10 +45,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 422fd4808ff..af50b103a60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -21,14 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the poisson loss metric between labels and predictions. * + * @param The data type for the metric result. 
*/ -public class Poisson extends MeanMetricWrapper implements LossMetric { +public class Poisson< T extends TNumber> extends MeanMetricWrapper< T> + implements LossMetric { /** * Creates a Poisson metric @@ -46,10 +46,7 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.poisson(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.poisson(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index e954169b2af..a0c016b70b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -21,16 +21,14 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. - * + *\ * @param The data type for the metric result. */ -public class SparseCategoricalCrossentropy extends MeanMetricWrapper - implements LossMetric { +public class SparseCategoricalCrossentropy + extends MeanMetricWrapper implements LossMetric { private final boolean fromLogits; private final int axis; @@ -40,8 +38,7 @@ public class SparseCategoricalCrossentropy extends MeanMetric * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
- * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a - * probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. * @param axis The dimension along which the entropy is computed. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -57,10 +54,7 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); + public Operand call(Operand labels, Operand predictions) { + return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 19b3b1d0ac4..bd331a85eda 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -21,14 +21,13 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A metric that computes the squared hinge loss metric between labels and predictions. * * @param The data type for the metric result. 
*/ -public class SquaredHinge extends MeanMetricWrapper implements LossMetric { +public class SquaredHinge extends MeanMetricWrapper + implements LossMetric { /** * Creates a SquaredHinge metric @@ -46,10 +45,7 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call( - Operand labels, Operand predictions) { - Operand tLabels = cast(getTF(), labels, getResultType()); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - return Losses.squaredHinge(getTF(), tLabels, tPredictions); + public Operand call(Operand labels, Operand predictions) { + return Losses.squaredHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 1fb3d3bb580..70bb8133698 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -31,5 +31,5 @@ public interface LossMetric { * @param predictions the predictions * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index 90531d21fde..a9bd5fac76e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,7 +32,8 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + Hinge instance = + new 
Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; @@ -54,7 +55,8 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + Hinge instance = + new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { -1, 1, -1, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index 5631bac15ee..75d9ef93168 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -55,7 +55,8 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); + Poisson instance = + new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; From 4add0287830957c1af133a55982b8affeb2c2dad Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 14:42:48 -0500 Subject: [PATCH 45/97] Reformat code --- .../annotations/org/tensorflow/op/Ops.java | 6 ++-- .../framework/losses/CategoricalHinge.java | 4 +-- .../tensorflow/framework/losses/Hinge.java | 20 +++++++------ .../tensorflow/framework/losses/Huber.java | 2 +- .../framework/losses/KLDivergence.java | 2 +- .../tensorflow/framework/losses/LogCosh.java | 2 +- 
.../org/tensorflow/framework/losses/Loss.java | 5 ++-- .../tensorflow/framework/losses/Losses.java | 28 +++++++++++++------ .../framework/losses/MeanAbsoluteError.java | 2 +- .../losses/MeanAbsolutePercentageError.java | 2 +- .../framework/losses/MeanSquaredError.java | 2 +- .../losses/MeanSquaredLogarithmicError.java | 2 +- .../framework/losses/impl/LossesHelper.java | 7 +---- .../framework/metrics/BinaryCrossentropy.java | 9 +++--- .../framework/metrics/CategoricalHinge.java | 4 +-- .../framework/metrics/CosineSimilarity.java | 2 +- .../tensorflow/framework/metrics/Hinge.java | 3 +- .../framework/metrics/KLDivergence.java | 5 ++-- .../framework/metrics/LogCoshError.java | 3 +- .../framework/metrics/MeanAbsoluteError.java | 2 +- .../metrics/MeanAbsolutePercentageError.java | 6 ++-- .../framework/metrics/MeanSquaredError.java | 4 +-- .../metrics/MeanSquaredLogarithmicError.java | 6 ++-- .../tensorflow/framework/metrics/Poisson.java | 6 ++-- .../SparseCategoricalCrossentropy.java | 13 +++++---- .../framework/metrics/SquaredHinge.java | 5 ++-- .../framework/metrics/impl/LossMetric.java | 2 +- .../framework/metrics/HingeTest.java | 6 ++-- .../framework/metrics/PoissonTest.java | 3 +- 29 files changed, 82 insertions(+), 81 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 7d6d159f5ef..acbae4dac6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -364,10 +364,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -390,8 +390,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = 
new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java index 4e9133d8835..73837ed1756 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalHinge.java @@ -25,7 +25,7 @@ *

loss = maximum(neg - pos + 1, 0) where neg=maximum((1-labels)*predictions) * and pos=sum(labels*predictions) * - *

labels values are expected to be 0 or 1.

+ *

labels values are expected to be 0 or 1. * *

Standalone usage: * @@ -100,7 +100,7 @@ public CategoricalHinge(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.categoricalHinge(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index 37e7e367b9b..db3569441ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -18,15 +18,16 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; + import static org.tensorflow.framework.utils.CastHelper.cast; /** * Computes the hinge loss between labels and predictions. * - *

loss = maximum(1 - labels * predictions, 0)

. + *

loss = maximum(1 - labels * predictions, 0). * - *

labels values are expected to be -1 or 1. - * If binary (0 or 1) labels are provided, they will be converted to -1 or 1.

+ *

labels values are expected to be -1 or 1. If binary (0 or 1) labels are provided, + * they will be converted to -1 or 1. * *

Standalone usage: * @@ -106,7 +107,7 @@ public Hinge(Ops tf, String name, Reduction reduction) { * label values are not in the set [-1., 0., 1.]. * * @param labels the truth values or labels, must be either -1, 0, or 1. Values are expected to be - * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. + * -1 or 1. If binary (0 or 1) labels are provided they will be converted to -1 or 1. * @param predictions the predictions, values must be in the range [0. to 1.] inclusive. * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor @@ -124,13 +125,16 @@ public Hinge(Ops tf, String name, Reduction reduction) { public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { @SuppressWarnings("unchecked") - Operand tLabels = predictions.type() == labels.type() ? - (Operand)labels : cast(tf, labels, predictions.type()); - tLabels = LossesHelper.valueCheck( + Operand tLabels = + predictions.type() == labels.type() + ? 
(Operand) labels + : cast(tf, labels, predictions.type()); + tLabels = + LossesHelper.valueCheck( getTF(), "labels value check [-1, 0, 1]", tLabels, - cast(getTF(), getTF().constant(new int[] { -1, 0, 1}), predictions.type())); + cast(getTF(), getTF().constant(new int[] {-1, 0, 1}), predictions.type())); Operand losses = Losses.hinge(getTF(), tLabels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index e8de632eb09..665a9ac157d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -131,7 +131,7 @@ public Huber(Ops tf, String name, float delta, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.huber(getTF(), labels, predictions, delta); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java index b3c0206b409..2aa1f72092b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/KLDivergence.java @@ -100,7 +100,7 @@ public KLDivergence(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.kullbackLeiblerDivergence(getTF(), 
labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 812260d9881..78325713e3e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -106,7 +106,7 @@ public LogCosh(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.logCosh(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java index 0f9b183f38c..cdd35d28aba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Loss.java @@ -25,7 +25,7 @@ public abstract class Loss { protected final Reduction reduction; /** - * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops @@ -64,7 +64,8 @@ protected Loss(Ops tf, String name, Reduction reduction) { * @param The data type of the predictions and loss. 
* @return the loss */ - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { return call(labels, predictions, null); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index a5ced3d1df8..2222ebb41f8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -102,8 +102,10 @@ public static Operand meanAbsolutePercentageError( tf.math.abs( tf.math.div( tf.math.sub(tLabels, predictions), - tf.math.maximum(tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); - return tf.math.mul(cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); + tf.math.maximum( + tf.math.abs(tLabels), cast(tf, tf.constant(EPSILON), predictionType)))); + return tf.math.mul( + cast(tf, tf.constant(100), predictionType), tf.math.mean(diff, tf.constant(-1))); } /** @@ -149,7 +151,11 @@ public static Operand meanSquaredLogarithmicError( * @return the binary crossentropy loss. */ public static Operand binaryCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, float labelSmoothing) { + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + float labelSmoothing) { Operand tLabels = cast(tf, labels, predictions.type()); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = ops.getTarget(); @@ -214,9 +220,10 @@ private static Operand binaryCrossentropyHelper( * @param labels true targets * @param predictions the predictions * @param fromLogits Whether to interpret predictions as a tensor of logit values - * @param labelSmoothing Float in [0, 1]. When > 0, label values are smoothed, meaning the - * confidence on label values are relaxed. e.g. 
labelSmoothing=0.2 means that we will use a - * value of 0.1 for label 0 and 0.9 for label 1 + * @param labelSmoothing Float in [0, 1]. When > 0, label values are + * smoothed, meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and + * 0.9 for label 1 * @param axis the * @param the data type of the predictions and labels * @return the categorical crossentropy loss. @@ -503,7 +510,11 @@ public static Operand poisson( * @return the sparse categorical crossentropy loss */ public static Operand sparseCategoricalCrossentropy( - Ops tf, Operand labels, Operand predictions, boolean fromLogits, int axis) { + Ops tf, + Operand labels, + Operand predictions, + boolean fromLogits, + int axis) { Class predictionType = predictions.type(); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); @@ -650,8 +661,7 @@ public static Operand l2Normalize(Ops tf, Operand x, i Operand squareSum = tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); Operand invNorm = - tf.math.rsqrt( - tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); + tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); return tf.math.mul(x, invNorm); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java index 594de1e1448..03a3cf70110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsoluteError.java @@ -96,7 +96,7 @@ public MeanAbsoluteError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, 
Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsoluteError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java index 275a2e136a0..6c5242df4f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanAbsolutePercentageError.java @@ -96,7 +96,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanAbsolutePercentageError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java index 31df3e70e0b..f975db55c44 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredError.java @@ -96,7 +96,7 @@ public MeanSquaredError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java index bef990d22bc..11b8e157e90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/MeanSquaredLogarithmicError.java @@ -96,7 +96,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, Reduction reduction) { /** {@inheritDoc} */ @Override public Operand call( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, Operand predictions, Operand sampleWeights) { Operand losses = Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); return LossesHelper.computeWeightedLoss(getTF(), losses, getReduction(), sampleWeights); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index f6b0de71b0d..66bdd839f09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -53,7 +53,6 @@ public class LossesHelper { * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . - * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, (prediction, @@ -82,7 +81,6 @@ public static LossTuple squeezeOrExpandDimensions( * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * * prediction. 
- * @param the data type for the labels, predictions and result * @return LossTuple of predictions, labels and sampleWeight * . Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, only the possibly @@ -182,7 +180,6 @@ private static Operand maybeExpandWeights( * @param labels Label values, a Tensor whose dimensions match predictions * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. - * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -198,7 +195,6 @@ public static LossTuple removeSqueezableDimensions( * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). - * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -222,8 +218,7 @@ public static LossTuple removeSqueezableDimensions( } // Use dynamic rank. 
- // TODO: hold for lazy select feature, - // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // TODO Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index abd2dcbbf40..d8bb2a41116 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,8 +21,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. * @@ -31,8 +29,8 @@ * * @param The data type for the metric result */ -public class BinaryCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class BinaryCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final float labelSmoothing; @@ -42,7 +40,8 @@ public class BinaryCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. 
When > 0, * compute the loss between the predicted labels and a smoothed version of the true labels, * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index c70f2d8643b..4800fc43c49 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result */ -public class CategoricalHinge< T extends TNumber> extends MeanMetricWrapper +public class CategoricalHinge extends MeanMetricWrapper implements LossMetric { /** @@ -45,7 +45,7 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.categoricalHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 5abbd095420..3ae67072955 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. 
*/ -public class CosineSimilarity< T extends TNumber> extends MeanMetricWrapper< T> +public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_AXIS = -1; private final int[] axis; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index e0aced6fa3e..3b84b81e071 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -26,8 +26,7 @@ * * @param The data type for the metric result. */ -public class Hinge extends MeanMetricWrapper - implements LossMetric { +public class Hinge extends MeanMetricWrapper implements LossMetric { /** * Creates a Hinge metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index fa09f2784b5..f631f562e1d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -27,8 +27,7 @@ * * @param The data type for the metric result. 
*/ -public class KLDivergence extends MeanMetricWrapper - implements LossMetric { +public class KLDivergence extends MeanMetricWrapper implements LossMetric { /** * Creates a KLDivergence metric @@ -46,7 +45,7 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index c43551a6948..046937e228b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -27,8 +27,7 @@ * * @param The data type for the metric result. */ -public class LogCoshError extends MeanMetricWrapper< T> - implements LossMetric { +public class LogCoshError extends MeanMetricWrapper implements LossMetric { /** * Creates a LogCoshError metric diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index d343ec77ab0..977f61648a1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. 
*/ -public class MeanAbsoluteError extends MeanMetricWrapper< T> +public class MeanAbsoluteError extends MeanMetricWrapper implements LossMetric { /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index dd7d151260b..bad5255969a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -26,8 +26,8 @@ * * @param The data type for the metric result. */ -public class MeanAbsolutePercentageError - extends MeanMetricWrapper implements LossMetric { +public class MeanAbsolutePercentageError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -45,7 +45,7 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index c2bef576b30..5b0d9ec43b3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -26,7 +26,7 @@ * * @param The data type for the metric result. 
*/ -public class MeanSquaredError< T extends TNumber> extends MeanMetricWrapper< T> +public class MeanSquaredError extends MeanMetricWrapper implements LossMetric { /** @@ -45,7 +45,7 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index c1cf4ca6c9a..35044fee956 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -26,8 +26,8 @@ * * @param The data type for the metric result. 
*/ -public class MeanSquaredLogarithmicError - extends MeanMetricWrapper implements LossMetric { +public class MeanSquaredLogarithmicError extends MeanMetricWrapper + implements LossMetric { /** * Creates a Mean Absolute Error metric @@ -45,7 +45,7 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index af50b103a60..700099d3375 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -24,11 +24,9 @@ /** * A metric that computes the poisson loss metric between labels and predictions. * - * @param The data type for the metric result. 
*/ -public class Poisson< T extends TNumber> extends MeanMetricWrapper< T> - implements LossMetric { +public class Poisson extends MeanMetricWrapper implements LossMetric { /** * Creates a Poisson metric @@ -46,7 +44,7 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.poisson(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index a0c016b70b3..aa7ca316378 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -23,12 +23,12 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and - * predicted labels. - *\ + * predicted labels. \ + * * @param The data type for the metric result. */ -public class SparseCategoricalCrossentropy - extends MeanMetricWrapper implements LossMetric { +public class SparseCategoricalCrossentropy extends MeanMetricWrapper + implements LossMetric { private final boolean fromLogits; private final int axis; @@ -38,7 +38,8 @@ public class SparseCategoricalCrossentropy * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. - * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a + * probability distribution. * @param axis The dimension along which the entropy is computed. 
* @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -54,7 +55,7 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index bd331a85eda..01f4a403f84 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -26,8 +26,7 @@ * * @param The data type for the metric result. */ -public class SquaredHinge extends MeanMetricWrapper - implements LossMetric { +public class SquaredHinge extends MeanMetricWrapper implements LossMetric { /** * Creates a SquaredHinge metric @@ -45,7 +44,7 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { return Losses.squaredHinge(getTF(), labels, predictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 70bb8133698..037d634cd4a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -31,5 +31,5 @@ public interface LossMetric { * @param predictions the predictions * @return the loss */ - Operand call(Operand 
labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java index a9bd5fac76e..90531d21fde 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -32,8 +32,7 @@ class HingeTest { public void testUnweighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; @@ -55,8 +54,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Hinge instance = - new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + Hinge instance = new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); session.run(instance.resetStates()); int[] trueArray = { -1, 1, -1, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java index 75d9ef93168..5631bac15ee 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -55,8 +55,7 @@ public void testUnweighted() { public void testWeighted() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Poisson instance = - new Poisson<>(tf, 
"Poisson_testWeighted", 1001L, TFloat32.class); + Poisson instance = new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {4, 8, 12, 8, 1, 3}; float[] predArray = {1, 9, 2, 5, 2, 6}; From c23163caa84c960bf44deda55eccbb65ab2f8e05 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 1 Feb 2021 18:40:00 -0500 Subject: [PATCH 46/97] Change order of TrainOps and QuantiQuantizationOps. For some reason, when I build it reverses these 2 from master's version. --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index acbae4dac6b..7d6d159f5ef 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -364,10 +364,10 @@ public final class Ops { public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -390,8 +390,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /** From 7bd1fcf197942fb279e8956455848b357f188cc9 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Feb 2021 15:32:07 -0500 Subject: [PATCH 47/97] Fix LossMetric to change abstract "call" method to use gneric parameter for predictions instead of . 
--- .../framework/metrics/BinaryCrossentropy.java | 8 ++++++-- .../framework/metrics/CategoricalCrossentropy.java | 8 ++++++-- .../framework/metrics/CategoricalHinge.java | 8 ++++++-- .../framework/metrics/CosineSimilarity.java | 11 ++++++++--- .../java/org/tensorflow/framework/metrics/Hinge.java | 8 ++++++-- .../tensorflow/framework/metrics/KLDivergence.java | 8 ++++++-- .../tensorflow/framework/metrics/LogCoshError.java | 8 ++++++-- .../framework/metrics/MeanAbsoluteError.java | 8 ++++++-- .../metrics/MeanAbsolutePercentageError.java | 8 ++++++-- .../framework/metrics/MeanSquaredError.java | 8 ++++++-- .../metrics/MeanSquaredLogarithmicError.java | 8 ++++++-- .../org/tensorflow/framework/metrics/Poisson.java | 8 ++++++-- .../metrics/SparseCategoricalCrossentropy.java | 8 ++++++-- .../tensorflow/framework/metrics/SquaredHinge.java | 8 ++++++-- .../tensorflow/framework/metrics/impl/LossMetric.java | 2 +- 15 files changed, 87 insertions(+), 30 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index d8bb2a41116..263b8a789ed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. 
* @@ -60,7 +62,9 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index be43f34b92e..cbe0127295f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical cross-entropy loss between true labels and predicted * labels. 
@@ -99,8 +101,10 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalCrossentropy( - getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + getTF(), tLabels, tPredictions, fromLogits, labelSmoothing, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 4800fc43c49..ff814ae6ed3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A Metric that computes the categorical hinge loss metric between labels and predictions. 
* @@ -45,7 +47,9 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.categoricalHinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.categoricalHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 3ae67072955..d64136d0d90 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the cosine similarity metric between labels and predictions. 
* @@ -76,8 +78,11 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity - return Losses.cosineSimilarity(getTF(), labels, predictions, axis); + public Operand call(Operand labels, Operand predictions) { + // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, + // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.cosineSimilarity(getTF(), tLabels, tPredictions, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 3b84b81e071..7a37cbeddbe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the hinge loss metric between labels and predictions. 
* @@ -44,7 +46,9 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.hinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.hinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index f631f562e1d..3027bb2f460 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the Kullback-Leibler divergence loss metric between labels and * predictions. 
@@ -45,7 +47,9 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 046937e228b..ca84e651988 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric * between labels and predictions. 
@@ -45,7 +47,9 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.logCosh(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.logCosh(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 977f61648a1..c91cb0df1ef 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. 
* @@ -45,7 +47,9 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsoluteError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index bad5255969a..6cc96a4fb88 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. 
* @@ -45,7 +47,9 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 5b0d9ec43b3..1fce9998270 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. 
* @@ -45,7 +47,9 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 35044fee956..900359db88b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the mean of absolute difference between labels and predictions. 
* @@ -45,7 +47,9 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 700099d3375..3572c155b96 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the poisson loss metric between labels and predictions. 
* @@ -44,7 +46,9 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.poisson(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.poisson(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index aa7ca316378..a74f575a4a8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. 
\ @@ -55,7 +57,9 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 01f4a403f84..6bee2ccf8e4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -21,6 +21,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * A metric that computes the squared hinge loss metric between labels and predictions. 
* @@ -44,7 +46,9 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { - return Losses.squaredHinge(getTF(), labels, predictions); + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Losses.squaredHinge(getTF(), tLabels, tPredictions); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 037d634cd4a..1fb3d3bb580 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -31,5 +31,5 @@ public interface LossMetric { * @param predictions the predictions * @return the loss */ - Operand call(Operand labels, Operand predictions); + Operand call(Operand labels, Operand predictions); } From 74a548f6373c0b564dba1b8b23a6110dd5c8eed0 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 6 Feb 2021 16:17:11 -0500 Subject: [PATCH 48/97] Reformat code, fix javadoc --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 6 +++--- .../main/java/org/tensorflow/framework/losses/Hinge.java | 7 +------ .../main/java/org/tensorflow/framework/losses/Huber.java | 4 +++- .../main/java/org/tensorflow/framework/losses/LogCosh.java | 3 ++- .../main/java/org/tensorflow/framework/losses/Losses.java | 1 + .../org/tensorflow/framework/losses/impl/LossesHelper.java | 7 ++++++- .../tensorflow/framework/metrics/BinaryCrossentropy.java | 3 ++- .../framework/metrics/CategoricalCrossentropy.java | 3 ++- .../org/tensorflow/framework/metrics/CategoricalHinge.java | 3 ++- .../org/tensorflow/framework/metrics/CosineSimilarity.java | 3 ++- 
.../main/java/org/tensorflow/framework/metrics/Hinge.java | 3 ++- .../org/tensorflow/framework/metrics/KLDivergence.java | 3 ++- .../org/tensorflow/framework/metrics/LogCoshError.java | 3 ++- .../tensorflow/framework/metrics/MeanAbsoluteError.java | 3 ++- .../framework/metrics/MeanAbsolutePercentageError.java | 3 ++- .../org/tensorflow/framework/metrics/MeanSquaredError.java | 3 ++- .../framework/metrics/MeanSquaredLogarithmicError.java | 3 ++- .../java/org/tensorflow/framework/metrics/Poisson.java | 3 ++- .../framework/metrics/SparseCategoricalCrossentropy.java | 3 ++- .../org/tensorflow/framework/metrics/SquaredHinge.java | 3 ++- 20 files changed, 44 insertions(+), 26 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 7d6d159f5ef..acbae4dac6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -364,10 +364,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -390,8 +390,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java index db3569441ef..d4c350ef06c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Hinge.java @@ -117,18 +117,13 @@ public Hinge(Ops tf, String 
name, Reduction reduction) { * predictions is scaled by the corresponding value of SampleWeights. (Note on dN-1: all loss * functions reduce by 1 dimension, usually axis=-1.) * @param The data type of the predictions, sampleWeights and loss. - * @param The data type of the labels. * @return the loss * @throws IllegalArgumentException if the predictions are outside the range [0.-1.]. */ @Override public Operand call( Operand labels, Operand predictions, Operand sampleWeights) { - @SuppressWarnings("unchecked") - Operand tLabels = - predictions.type() == labels.type() - ? (Operand) labels - : cast(tf, labels, predictions.type()); + Operand tLabels = cast(tf, labels, predictions.type()); tLabels = LossesHelper.valueCheck( getTF(), diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java index 665a9ac157d..b1aee1b0656 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Huber.java @@ -89,6 +89,7 @@ public Huber(Ops tf) { * Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public Huber(Ops tf, String name) { this(tf, name, DELTA_DEFAULT, Reduction.AUTO); @@ -109,6 +110,7 @@ public Huber(Ops tf, Reduction reduction) { * Creates a Huber Loss using {@link #DELTA_DEFAULT} as the delta * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ public Huber(Ops tf, String name, Reduction reduction) { @@ -119,7 +121,7 @@ public Huber(Ops tf, String name, Reduction reduction) { * Creates a Huber Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param delta the point where the Huber loss function changes from quadratic to linear. * @param reduction Type of Reduction to apply to the loss. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java index 78325713e3e..a11d582e527 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/LogCosh.java @@ -77,6 +77,7 @@ public LogCosh(Ops tf) { * Creates a LogCosh Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} * * @param tf the TensorFlow Ops + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. */ public LogCosh(Ops tf, String name) { this(tf, name, Reduction.AUTO); @@ -96,7 +97,7 @@ public LogCosh(Ops tf, Reduction reduction) { * Creates a LogCosh Loss * * @param tf the TensorFlow Ops - * @param name the name of the loss + * @param name the name of the loss, if null then {@link Class#getSimpleName()} is used. * @param reduction Type of Reduction to apply to the loss. 
*/ public LogCosh(Ops tf, String name, Reduction reduction) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 2222ebb41f8..9aa94cf7fcf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -655,6 +655,7 @@ private static Operand smoothCategoricalLabels( * @param tf The TensorFlow Ops * @param x the input * @param axis Dimension along which to normalize. + * @param the data type for the input and the result * @return the normalized values based on L2 norm */ public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index 66bdd839f09..f6b0de71b0d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -53,6 +53,7 @@ public class LossesHelper { * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction * . + * @param the data type for the labels, predictions and result * @return LossTuple of prediction, label,sampleWeight will * be null. Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, (prediction, @@ -81,6 +82,7 @@ public static LossTuple squeezeOrExpandDimensions( * @param sampleWeights Optional sample weight(s) Operand whose dimensions match * * prediction. + * @param the data type for the labels, predictions and result * @return LossTuple of predictions, labels and sampleWeight * . 
Each of them possibly has the last dimension squeezed, sampleWeight * could be extended by one dimension. If sampleWeight is null, only the possibly @@ -180,6 +182,7 @@ private static Operand maybeExpandWeights( * @param labels Label values, a Tensor whose dimensions match predictions * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -195,6 +198,7 @@ public static LossTuple removeSqueezableDimensions( * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. */ public static LossTuple removeSqueezableDimensions( @@ -218,7 +222,8 @@ public static LossTuple removeSqueezableDimensions( } // Use dynamic rank. 
- // TODO Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + // TODO: hold for lazy select feature, + // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { /* * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 263b8a789ed..48ee244eafb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -62,7 +62,8 @@ public BinaryCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.binaryCrossentropy(getTF(), tLabels, tPredictions, fromLogits, labelSmoothing); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index cbe0127295f..b22e5415f79 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -101,7 +101,8 @@ public CategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), 
predictions, getResultType()); return Losses.categoricalCrossentropy( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index ff814ae6ed3..4266cc487c0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -47,7 +47,8 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.categoricalHinge(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index d64136d0d90..840f255c5ab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -78,7 +78,8 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { // NOTE: metrics.CosineSimilarity is Losses.cosineSimilarity, // while losses.CosineSimilarity is the negative of Losses.cosineSimilarity Operand tLabels = cast(getTF(), labels, getResultType()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 
7a37cbeddbe..46ccd2859ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -46,7 +46,8 @@ public Hinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.hinge(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 3027bb2f460..9ffcd6189f1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -47,7 +47,8 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.kullbackLeiblerDivergence(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index ca84e651988..59e24f57110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -47,7 +47,8 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public 
Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.logCosh(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index c91cb0df1ef..1cc6d0b6f99 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -47,7 +47,8 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanAbsoluteError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 6cc96a4fb88..8c6720b58f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -47,7 +47,8 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, 
getResultType()); return Losses.meanAbsolutePercentageError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 1fce9998270..3c4c79d39ba 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -47,7 +47,8 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanSquaredError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index 900359db88b..d525bb76648 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -47,7 +47,8 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.meanSquaredLogarithmicError(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 3572c155b96..422fd4808ff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -46,7 +46,8 @@ public Poisson(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.poisson(getTF(), tLabels, tPredictions); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index a74f575a4a8..9949f0c6b60 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -57,7 +57,8 @@ public SparseCategoricalCrossentropy( /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.sparseCategoricalCrossentropy(getTF(), tLabels, tPredictions, fromLogits, axis); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 6bee2ccf8e4..19b3b1d0ac4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -46,7 +46,8 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Losses.squaredHinge(getTF(), tLabels, tPredictions); From 0eab19cf0a823686ee7519afc8316add5b17303c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Feb 2021 13:51:02 -0500 Subject: [PATCH 49/97] Update with new generic parameters --- .../org/tensorflow/framework/metrics/AUC.java | 1018 +++++++++++++++++ .../framework/metrics/AUCCurve.java | 36 + .../framework/metrics/AUCSummationMethod.java | 41 + .../framework/metrics/Accuracy.java | 89 ++ .../framework/metrics/BinaryAccuracy.java | 100 ++ .../metrics/CategoricalAccuracy.java | 85 ++ .../framework/metrics/FalseNegatives.java | 128 +++ .../framework/metrics/FalsePositives.java | 129 +++ .../tensorflow/framework/metrics/MeanIoU.java | 163 +++ .../framework/metrics/MeanRelativeError.java | 173 +++ .../framework/metrics/MeanTensor.java | 186 +++ .../tensorflow/framework/metrics/Metrics.java | 52 +- .../framework/metrics/Precision.java | 400 +++++++ .../framework/metrics/PrecisionAtRecall.java | 122 ++ .../tensorflow/framework/metrics/Recall.java | 426 +++++++ .../framework/metrics/RecallAtPrecision.java | 132 +++ .../metrics/RootMeanSquaredError.java | 87 ++ .../metrics/SensitivityAtSpecificity.java | 150 +++ .../metrics/SparseCategoricalAccuracy.java | 135 +++ .../SparseTopKCategoricalAccuracy.java | 70 ++ .../metrics/SpecificityAtSensitivity.java | 151 +++ .../org/tensorflow/framework/metrics/Sum.java | 60 + .../metrics/TopKCategoricalAccuracy.java | 70 ++ .../framework/metrics/TrueNegatives.java | 129 +++ .../framework/metrics/TruePositives.java | 128 +++ 
.../impl/ConfusionMatrixConditionCount.java | 186 +++ .../metrics/impl/ConfusionMatrixEnum.java | 57 + .../framework/metrics/impl/MetricsHelper.java | 487 +++++++- .../impl/SensitivitySpecificityBase.java | 277 +++++ .../framework/metrics/impl/SymbolicShape.java | 56 + .../metrics/impl/WeightsBroadcastOps.java | 186 +++ .../framework/utils/SparseTensor.java | 75 ++ .../tensorflow/framework/metrics/AUCTest.java | 324 ++++++ .../framework/metrics/AccuracyTest.java | 130 +++ .../framework/metrics/BinaryAccuracyTest.java | 177 +++ .../metrics/CategoricalAccuracyTest.java | 156 +++ .../framework/metrics/FalseNegativesTest.java | 141 +++ .../framework/metrics/FalsePositivesTest.java | 148 +++ .../framework/metrics/MeanIoUTest.java | 109 ++ .../metrics/MeanRelativeErrorTest.java | 100 ++ .../framework/metrics/MeanTensorTest.java | 119 ++ .../metrics/PrecisionAtRecallTest.java | 179 +++ .../framework/metrics/PrecisionTest.java | 339 ++++++ .../metrics/RecallAtPrecisionTest.java | 207 ++++ .../framework/metrics/RecallTest.java | 341 ++++++ .../metrics/RootMeanSquaredErrorTest.java | 72 ++ .../metrics/SensitivityAtSpecificityTest.java | 185 +++ .../metrics/SpecificityAtSensitivityTest.java | 184 +++ .../tensorflow/framework/metrics/SumTest.java | 113 ++ .../metrics/TopKCategoricalAccuracyTest.java | 103 ++ .../framework/metrics/TrueNegativesTest.java | 141 +++ .../framework/metrics/TruePositivesTest.java | 141 +++ .../metrics/impl/AssertBroadcastableTest.java | 1 + 53 files changed, 8988 insertions(+), 6 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java create mode 100644 
tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java create mode 100644 
tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java create mode 
100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java new file mode 100644 index 00000000000..62311c3cda5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -0,0 +1,1018 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.framework.metrics.impl.SymbolicShape; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. + * + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives}, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the AUC. To discretize the AUC + * curve, a linearly spaced set of thresholds is used to compute pairs of recall and precision + * values. The area under the ROC-curve is therefore computed using the height of the recall values + * by the false positive rate, while the area under the PR-curve is computed using the height of + * the precision values by the recall. + * + *

This value is ultimately returned as auc, an idempotent operation that computes the area + * under a discretized curve of precision versus recall values (computed using the aforementioned + * variables). The numThresholds variable controls the degree of discretization with larger + * numbers of thresholds more closely approximating the true AUC. The quality of the approximation + * may vary dramatically depending on {@code numThresholds}. The thresholds parameter can be used to + * manually specify thresholds which split the predictions more evenly. + * + *

For best results, predictions should be distributed approximately uniformly in the range [0, + * 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor if this is not + * the case. Setting summationMethod to minoring or majoring can help quantify the error in + * the approximation by providing a lower or upper bound estimate of the AUC. + *

+ *

+ * Usage:
+ *

+ * AUC m = new AUC( getTF(), 3);
+ * m.updateState( getTF().constant(new float[] {0, 0, 1,1}),
+ *          getTF().constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
+ *
+ * // threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
+ * // tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+ * // recall = [1, 0.5, 0], fpRate = [1, 0, 0]
+ * // auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75
+ * Operand<TFloat32> result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.75
+ * 
+ *
+ * m.resetStates();
+ * m.updateState( getTF().constant(new float[] {0, 0, 1, 1}),
+ *                 getTF().constant(new float[] {0f, 0.5f, 0.3f, 0.9f}),
+ *                 getTF().constant(new float[] {1, 0, 0, 1}));
+ * result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 1.0
+ * 
+ * + * @param The data type for the metric result + */ +public class AUC extends Metric { + + /** Default Fuzz factor. */ + public static final float EPSILON = 1e-7f; + + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + public static final int DEFAULT_NUM_THRESHOLDS = 200; + public static final String DEFAULT_NAME = "auc"; + + private final int numThresholds; + private final AUCCurve curve; + private final AUCSummationMethod summationMethod; + private final float[] thresholds; + private final boolean multiLabel; + private final String truePositivesName; + private final String falsePositivesName; + private final String trueNegativesName; + private final String falseNegativesName; + private final Map> initializers = new HashMap<>(); + private final Class type; + private Integer numLabels; + private Operand labelWeights; + private Variable truePositives; + private Variable falsePositives; + private Variable trueNegatives; + private Variable falseNegatives; + private boolean initialized; + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link #DEFAULT_NUM_THRESHOLDS} for the numThresholds, {@link AUCCurve#ROC} for the curve type, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for + * thresholds, false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, long seed, Class type) { + this( + tf, + null, + DEFAULT_NUM_THRESHOLDS, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NUM_THRESHOLDS} for the + * numThresholds, {@link AUCCurve#ROC} for the curve type, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, null for thresholds, + * false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, long seed, Class type) { + this( + tf, + name, + DEFAULT_NUM_THRESHOLDS, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the + * summation method, null for thresholds, false for multiLabel, and + * null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, int numThresholds, long seed, Class type) { + this( + tf, + null, + numThresholds, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the + * summation method, null for numThresholds, false for multiLabel, and + * null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, float[] thresholds, long seed, Class type) { + this( + tf, + null, + null, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric. using {@link AUCCurve#ROC} for the curve type, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for + * thresholds, false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, String name, int numThresholds, long seed, Class type) { + this( + tf, + name, + numThresholds, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the summation + * method, {@link #DEFAULT_NUM_THRESHOLDS} num thresholds, false for multiLabel, and + * null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { + this( + tf, + name, + null, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link AUCSummationMethod#INTERPOLATION} for + * the summation method, null for thresholds, false for multiLabel, and + * null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, int numThresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + name, + numThresholds, + curve, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, {@link #DEFAULT_NUM_THRESHOLDS} num + * thresholds, false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + name, + null, + curve, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for + * thresholds, false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. 
+ * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + null, + numThresholds, + curve, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, false for multiLabel, + * and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + null, + null, + curve, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, + * null for thresholds, false for multiLabel, and null for + * labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. 
+ * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + int numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, null, numThresholds, curve, summationMethod, null, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * null for numThresholds, false for multiLabel, and null + * for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + float[] thresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, null, null, curve, summationMethod, thresholds, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric. using null for thresholds, + * false for multiLabel, and null for labelWeights. 
+ * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used, + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + String name, + int numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, name, numThresholds, curve, summationMethod, null, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric. using null> for the numThresholds, + * false for multiLabel, and null for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC( + Ops tf, + String name, + float[] thresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, name, null, curve, summationMethod, thresholds, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS} + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein + * AUC is computed separately for each label and then averaged across labels, or (when false) + * if the data should be flattened into a single label before AUC computation. In the latter + * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an + * individual data point. Should be set to false for multi-class data. + * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When + * multi_label is True, the weights are applied to the individual label AUCs when they are + * averaged to produce the multi-label AUC. When it's false, they are used to weight the + * individual label predictions in computing the confusion matrix on the flattened data. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the confusion matrix variables. + * @throws IllegalArgumentException if numThresholds is less than 2 and thresholds is null, or if + * a threshold value is less than 0 or greater than 1. + */ + public AUC( + Ops tf, + String name, + Integer numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + float[] thresholds, + boolean multiLabel, + Operand labelWeights, + long seed, + Class type) { + super(tf, name == null ? DEFAULT_NAME : name, seed); + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); + this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES); + this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); + this.curve = curve; + this.summationMethod = summationMethod; + this.type = type; + + this.multiLabel = multiLabel; + + if (thresholds != null) { // ignore numThresholds + for (float t : thresholds) + if (t < 0.0f || t > 1.0f) + throw new IllegalArgumentException( + String.format( + "Threshold values must be in [0, 1]. Invalid values: %s", + Arrays.toString(thresholds))); + this.numThresholds = thresholds.length + 2; + Arrays.sort(thresholds); + } else { + if (numThresholds <= 1) throw new IllegalArgumentException("numThresholds must be > 1."); + this.numThresholds = numThresholds; + thresholds = new float[numThresholds - 2]; + // linearly interpolate (numThresholds - 2) thresholds between endpoints + for (int i = 0; i < thresholds.length; i++) { + thresholds[i] = (i + 1) * 1.0f / (this.numThresholds - 1); + } + } + // Add an endpoint "threshold" below zero and above one for either + // threshold method to account for floating point imprecision. 
+ if (thresholds.length != this.numThresholds - 2) + throw new IllegalArgumentException( + "Thresholds length must contain numThresholds - 2 entries"); + this.thresholds = new float[this.numThresholds]; + this.thresholds[0] = -EPSILON; + System.arraycopy(thresholds, 0, this.thresholds, 1, thresholds.length); + this.thresholds[this.numThresholds - 1] = 1 + EPSILON; + + if (labelWeights != null) { + // assert that labelWeights are non-negative. + + this.labelWeights = labelWeights; + Op checks = + getTF() + .withSubScope("AUC") + .assertThat( + getTF() + .math + .greaterEqual( + labelWeights, cast(getTF(), getTF().constant(0), labelWeights.type())), + Collections.singletonList( + getTF().constant("All values of `labelWeights` must be non-negative."))); + + Ops ltf = + getTF() + .withSubScope("updateState") + .withControlDependencies(Collections.singletonList(checks)); + + this.labelWeights = ltf.identity(this.labelWeights); + } + + if (this.multiLabel) { + this.numLabels = null; + } + } + + /** + * Initialize truePositives, falsePositives, trueNegatives, and falseNegatives variables, given + * the shape of the data. + * + * @param shape the prediction shape if called from updateState, otherwise null + */ + @SuppressWarnings("unchecked") + private Map> build(Shape shape) { + Shape variableShape; + if (initialized) { + return Collections.EMPTY_MAP; + } + + if (this.isMultiLabel()) { + if (shape == null) { + throw new IllegalArgumentException("For multiLabel, a shape must be provided"); + } + if (shape.numDimensions() != 2) + throw new IllegalArgumentException( + String.format( + "labels must have rank=2 when multiLabel is true. 
Found rank %d.", + shape.numDimensions())); + this.numLabels = (int) shape.size(1); + variableShape = Shape.of(this.numThresholds, this.numLabels); + } else { + variableShape = Shape.of(this.numThresholds); + } + + Zeros zeros = new Zeros<>(getTF()); + Operand zero = zeros.call(getTF().constant(variableShape), type); + if (truePositives == null) { + truePositives = getTF().withName(getTruePositivesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTF().assign(truePositives, zero)); + } + + if (falsePositives == null) { + falsePositives = getTF().withName(getFalsePositivesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, getTF().assign(falsePositives, zero)); + } + + if (trueNegatives == null) { + trueNegatives = getTF().withName(getTrueNegativesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, getTF().assign(trueNegatives, zero)); + } + + if (falseNegatives == null) { + falseNegatives = getTF().withName(getFalseNegativesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, getTF().assign(falseNegatives, zero)); + } + + this.initialized = true; + return initializers; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + + Operand lLabels = cast(getTF(), labels, type); + Operand lPredictions = cast(getTF(), predictions, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(getTF(), sampleWeights, type) : null; + List updateOperations = new ArrayList<>(); + Map> varInitializers = Collections.EMPTY_MAP; + if (!this.initialized) { + varInitializers = build(lPredictions.shape()); + } + if (this.isMultiLabel() || this.getLabelWeights() != null) { + List> symbols = new ArrayList<>(); + symbols.add(new SymbolicShape<>(lLabels, "N", "L")); + if (this.isMultiLabel()) { + symbols.add(new SymbolicShape<>(this.truePositives, "T", "L")); + symbols.add(new SymbolicShape<>(this.falsePositives, "T", "L")); + symbols.add(new SymbolicShape<>(this.trueNegatives, "T", "L")); + symbols.add(new SymbolicShape<>(this.falseNegatives, "T", "L")); + } + if (this.getLabelWeights() != null) { + symbols.add(new SymbolicShape<>(this.getLabelWeights(), "L", "")); + } + updateOperations.addAll( + MetricsHelper.assertShapes(getTF(), symbols, "Number of labels is not consistent.")); + } + if (this.isMultiLabel()) { + this.labelWeights = null; + } + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.falsePositives); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.trueNegatives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives); + + updateOperations.addAll( + MetricsHelper.updateConfusionMatrixVariables( + getTF(), + confusionMatrix, + varInitializers, + lLabels, + lPredictions, + this.thresholds, + null, + null, + tSampleWeights, + this.isMultiLabel(), + this.getLabelWeights())); + return updateOperations; + } + + /** + * Interpolation formula inspired by section 4 of Davis & Goadrich 2006. + * + * @return an approximation of the area under the P-R curve. 
+ */ + private Operand interpolatePRAuc() { + // truePositives[:self.numThresholds - 1] + Operand tp0 = + getTF() + .slice( + truePositives, + getTF().constant(new int[] {0}), + getTF().constant(new int[] {this.getNumThresholds() - 1})); + // truePositives[1:] + Operand tp1 = + getTF() + .slice( + truePositives, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1})); + + Operand dTP = getTF().math.sub(tp0, tp1); + + Operand p = getTF().math.add(truePositives, falsePositives); + + Operand dP = + getTF() + .math + .sub( + getTF() + .slice( + p, + getTF().constant(new int[] {0}), + getTF().constant(new int[] {this.getNumThresholds() - 1})), + getTF() + .slice(p, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1}))); + + Operand precisionSlope = + getTF() + .math + .divNoNan( + dTP, getTF().math.maximum(dP, getTF().dtypes.cast(getTF().constant(0), dP.type()))); + + Operand intercept = + getTF() + .math + .sub( + getTF() + .slice( + truePositives, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})), + getTF() + .math + .mul( + precisionSlope, + getTF() + .slice( + p, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})))); + + Operand safePRatio = + getTF() + .select( + getTF() + .math + .logicalAnd( + getTF() + .math + .greater( + getTF() + .slice( + p, + getTF().constant(new int[] {0}), + getTF().constant(new int[] {this.getNumThresholds() - 1})), + getTF().dtypes.cast(getTF().constant(0), p.type())), + getTF() + .math + .greater( + getTF() + .slice( + p, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})), + getTF().dtypes.cast(getTF().constant(0), p.type()))), + getTF() + .math + .divNoNan( + getTF() + .slice( + p, + getTF().constant(new int[] {0}), + getTF().constant(new int[] {this.getNumThresholds() - 1})), + getTF() + .math + .maximum( + getTF() + .slice( + p, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})), + getTF().dtypes.cast(getTF().constant(0), 
p.type()))), + getTF() + .onesLike( + getTF() + .slice( + p, + getTF().constant(new int[] {1}), + getTF().constant(new int[] {-1})))); + + Operand fn1 = + getTF() + .slice( + falseNegatives, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1})); + + Operand aucTotalPos = + getTF() + .math + .mul( + precisionSlope, + getTF().math.add(dTP, getTF().math.mul(intercept, getTF().math.log(safePRatio)))); + + Operand prAucIncrement = + getTF() + .math + .divNoNan( + aucTotalPos, + getTF() + .math + .maximum( + getTF().math.add(tp1, fn1), + getTF().dtypes.cast(getTF().constant(0), this.truePositives.type()))); + + if (this.isMultiLabel()) { + Operand byLabelAuc = getTF().reduceSum(prAucIncrement, getTF().constant(0)); + if (this.getLabelWeights() == null) { + return MetricsHelper.mean(getTF(), byLabelAuc); + } else { + return getTF() + .math + .divNoNan( + getTF() + .reduceSum( + getTF().math.mul(byLabelAuc, this.getLabelWeights()), + allAxes(getTF(), byLabelAuc)), + getTF().reduceSum(getLabelWeights(), allAxes(getTF(), getLabelWeights()))); + } + } else { + return getTF().reduceSum(prAucIncrement, allAxes(getTF(), prAucIncrement)); + } + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + + if (this.getCurve() == AUCCurve.PR + && this.getSummationMethod() == AUCSummationMethod.INTERPOLATION) { + return this.interpolatePRAuc(); + } + Ops tf = getTF(); + Operand x; + Operand y; + Operand recall = + getTF().math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + + if (this.getCurve() == AUCCurve.ROC) { + x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives)); + y = recall; + } else { // AUCCurve.PR + y = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + x = recall; + } + + // Find the rectangle heights based on `summationMethod`. 
+ // y[:self.numThresholds - 1] + Operand ySlice1 = + tf.slice( + y, tf.constant(new int[] {0}), tf.constant(new int[] {this.getNumThresholds() - 1})); + // y[1:] + Operand ySlice2 = tf.slice(y, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + + Operand heights = null; + switch (this.getSummationMethod()) { + case INTERPOLATION: + heights = + tf.math.div(tf.math.add(ySlice1, ySlice2), tf.dtypes.cast(tf.constant(2), y.type())); + break; + case MINORING: + heights = tf.math.minimum(ySlice1, ySlice2); + break; + case MAJORING: + heights = tf.math.maximum(ySlice1, ySlice2); + break; + } + + if (this.isMultiLabel()) { + Operand riemannTerms = + tf.math.mul( + tf.math.sub( + tf.slice( + x, + tf.constant(new int[] {0}), + tf.constant(new int[] {this.getNumThresholds() - 1})), + tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))), + heights); + Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0)); + + if (this.getLabelWeights() == null) { + return MetricsHelper.mean(tf, byLabelAuc); + } else { + return tf.math.divNoNan( + tf.reduceSum( + tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())), + tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()))); + } + + } else { + Operand slice1 = + tf.slice( + x, tf.constant(new int[] {0}), tf.constant(new int[] {this.getNumThresholds() - 1})); + Operand slice2 = tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand sub = tf.math.sub(slice1, slice2); + Operand operand = tf.math.mul(sub, heights); + return tf.reduceSum(operand, allAxes(tf, operand)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + List updateOperations = new ArrayList<>(initializers.values()); + return getTF().withSubScope("resetStates").withControlDependencies(updateOperations).noOp(); + } + + /** @return the numThresholds */ + public int getNumThresholds() { + return numThresholds; + } + + /** @return the curve */ + public AUCCurve getCurve() { 
+ return curve; + } + + /** @return the summationMethod */ + public AUCSummationMethod getSummationMethod() { + return summationMethod; + } + + /** @return the thresholds */ + public float[] getThresholds() { + return thresholds; + } + + /** @return the multiLabel */ + public boolean isMultiLabel() { + return multiLabel; + } + + /** @return the numLabels */ + public Integer getNumLabels() { + return numLabels; + } + + /** @param numLabels the numLabels to set */ + public void setNumLabels(Integer numLabels) { + this.numLabels = numLabels; + } + + /** @return the labelWeights */ + public Operand getLabelWeights() { + return labelWeights; + } + + /** @return the truePositives */ + public Variable getTruePositives() { + return truePositives; + } + + /** @return the falsePositives */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** @return the trueNegatives */ + public Variable getTrueNegatives() { + return trueNegatives; + } + + /** @return the falseNegatives */ + public Variable getFalseNegatives() { + return falseNegatives; + } + + /** @return the truePositivesName */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** @return the falsePositivesName */ + public String getFalsePositivesName() { + return falsePositivesName; + } + + /** @return the trueNegativesName */ + public String getTrueNegativesName() { + return trueNegativesName; + } + + /** @return the falseNegativesName */ + public String getFalseNegativesName() { + return falseNegativesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java new file mode 100644 index 00000000000..b5426a0dd8f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +/** + * Specifies the type of the curve to be computed, {@link #ROC} for a Receiver Operator + * Characteristic curve [default] or {@link #PR} for a Precision-Recall-curve. + */ +public enum AUCCurve { + /** Receiver Operator Characteristic curve */ + ROC, + /** Precision-Recall-curve */ + PR; + + /** + * Gets the AUCCurve enum value by name, regardless of case + * + * @param name the name of the AUCCurve enum value. + * @return the AUCCurve enum value. + */ + public AUCCurve get(String name) { + return AUCCurve.valueOf(name.toUpperCase()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java new file mode 100644 index 00000000000..09581c726d3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +/** + * Specifies the Riemann summation method used. {@link #INTERPOLATION} (default) applies mid-point + * summation scheme for ROC. For PR-AUC, interpolates (true/false) positives but not the ratio that + * is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left summation + * for increasing intervals and right summation for decreasing intervals; {@link #MAJORING} does the + * opposite. + * + * @see Davis & Goadrich. 2006 + * @see Riemann summation method + */ +public enum AUCSummationMethod { + INTERPOLATION, + MAJORING, + MINORING; + + /** + * Gets the AUCSummationMethod enum value by name, regardless of case + * + * @param name the name of the AUCSummationMethod enum value. + * @return the AUCSummationMethod enum value. + */ + public AUCSummationMethod get(String name) { + return AUCSummationMethod.valueOf(name.toUpperCase()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java new file mode 100644 index 00000000000..f69170e57b9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions equals labels. + * + *

This metric creates two local variables, total and count that are used to compute the + * frequency with which predictions matches labels. This frequency is + * ultimately returned as binary accuracy: an idempotent operation that simply divides total by + * count. + * + *

If sampleWeights is null>, weights default to 1. Use sampleWeights of 0 to mask + * values. + * + * @param The data type for the metric result + */ +public class Accuracy extends MeanMetricWrapper implements LossMetric { + + /** + * Creates an Accuracy Metric using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Accuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates an Accuracy Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Accuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + LossTuple tuple = + MetricsHelper.raggedAssertCompatibleAndGetFlatValues(getTF(), tLabels, tPredictions); + tLabels = tuple.getLabels(); + tPredictions = tuple.getTarget(); + + if (!predictions.shape().isCompatibleWith(labels.shape())) { + throw new IllegalArgumentException( + String.format( + "Shapes %s and %s are incompatible", + predictions.shape().toString(), labels.shape().toString())); + } + + // cast TBool to result type + return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java new file mode 100644 index 00000000000..9e7f0f874cc --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions matches binary labels. + * + *

This metric creates two local variables, total and count that are used to compute the + * frequency with which predictions matches labels. This frequency is + * ultimately returned as binary accuracy: an idempotent operation that simply divides total by + * count. + * + *

If sampleWeights is null>, weights default to 1. Use sampleWeights of 0 to mask + * values. + * + * @param The data type for the metric result + */ +public class BinaryAccuracy extends MeanMetricWrapper + implements LossMetric { + /** the default threshold value for deciding whether prediction values are 1 or 0 */ + public static final float DEFAULT_THRESHOLD = 0.5f; + + /** the threshold value for deciding whether prediction values are 1 or 0 */ + private final float threshold; + + /** + * Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name and + * {@link #DEFAULT_THRESHOLD} for the threshold value. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold for deciding whether prediction values are 1 or 0 + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, float threshold, long seed, Class type) { + this(tf, null, threshold, seed, type); + } + + /** + * Creates a BinaryAccuracy Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold for deciding whether prediction values are 1 or 0 + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class type) { + super(tf, name, seed, type); + this.threshold = threshold; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions) { + + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Operand thresholdCast = cast(getTF(), getTF().constant(threshold), getResultType()); + tPredictions = + cast(getTF(), getTF().math.greater(tPredictions, thresholdCast), getResultType()); + Operand tLabels = cast(getTF(), labels, getResultType()); + return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java new file mode 100644 index 00000000000..c0635746d4d --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions matches one-hot labels. + * + *

You can provide logits of classes as predictionsy_pred, since argmax + * of logits and probabilities are same. + * + *

This metric creates two local variables, total and count that are + * used to compute the frequency with which predictions matches labels. + * This frequency is ultimately returned as categorical accuracy: an idempotent operation that + * simply divides total by count. + * + *

predictions and labels should be passed in as vectors of + * probabilities, rather than as labels. If necessary, use {@link + * org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand + * labels as a vector. + * + *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. + * + * @param The data type for the metric result + */ +public class CategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a CategoricalAccuracy metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public CategoricalAccuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a CategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + super.setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand trueMax = getTF().math.argMax(labels, getTF().constant(-1)); + + Operand predMax = getTF().math.argMax(predictions, getTF().constant(-1)); + return cast(getTF(), getTF().math.equal(trueMax, predMax), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java new file mode 100644 index 00000000000..cf6f84af512 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of false negatives. + * + *

If sampleWeights is given, calculates the sum of the weights of false negatives. + * This metric creates one local variable, accumulator that is used to keep track of + * the number of false negatives. + * + *

If sampleWeightsnull + * sampleWeights The data type for the metric result + */ +public class FalseNegatives + extends ConfusionMatrixConditionCount { + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a FalseNegatives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalseNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalseNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range [0, 1]. 
A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.FALSE_NEGATIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java new file mode 100644 index 00000000000..629caaafb52 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of false positives. + * + *

If sampleWeights is given, calculates the sum of the weights of false positives. + * This metric creates one local variable, accumulator that is used to keep track of + * the number of false positives. + * + *

If sampleWeightsnull + * sampleWeights The data type for the metric result + */ +public class FalsePositives< T extends TNumber> + extends ConfusionMatrixConditionCount { + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a FalsePositives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalsePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalsePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range [0, 1]. 
A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.FALSE_POSITIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java new file mode 100644 index 00000000000..c8205565802 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -0,0 +1,163 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the mean Intersection-Over-Union metric. + * + *

Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, + * which first computes the IOU for each semantic class and then computes the average over classes. + * IOU is defined as follows: IOU = true_positive + * / (true_positive + false_positive + false_negative). The predictions are accumulated in a + * confusion matrix, weighted by sample_weight and the metric is then calculated from it. + * + *

If sampleWeight is null, weights default to 1. Use sample_weight of 0 to mask + * values. + * + * @param The data type for the metric result + */ +public class MeanIoU extends Metric { + + public static final String TOTAL_CONFUSION_MATRIX = "TOTAL_CONFUSION_MATRIX"; + private final String totalCMName; + private final Class type; + /** + * The possible number of labels the prediction task can have. This value must be provided, since + * a confusion matrix of dimension = [numClasses, numClasses] will be allocated. + */ + private final long numClasses; + + private Variable totalConfusionMatrix; + private Assign initializer; + + /** + * Creates a metric MeanIoU, using name as {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param numClasses The possible number of labels the prediction task can have + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + protected MeanIoU(Ops tf, long numClasses, long seed, Class type) { + this(tf, null, numClasses, seed, type); + } + + /** + * create a metric with reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param numClasses The possible number of labels the prediction task can have + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + protected MeanIoU(Ops tf, String name, long numClasses, long seed, Class type) { + super(tf, name, seed); + this.type = type; + this.totalCMName = this.getVariableName(TOTAL_CONFUSION_MATRIX); + this.numClasses = numClasses; + init(); + } + + private void init() { + Shape variableShape = Shape.of(numClasses, numClasses); + + if (totalConfusionMatrix == null) { + Zeros zeros = new Zeros<>(getTF()); + totalConfusionMatrix = + getTF().withName(totalCMName).variable(zeros.call(getTF().constant(variableShape), type)); + initializer = + getTF().assign(totalConfusionMatrix, zeros.call(getTF().constant(variableShape), type)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializer; + } + + /** + * Gets the initializer for the totalConfusionMatrix variable + * + * @return the initializer for the totalConfusionMatrix variable + */ + public Assign getInitializer() { + return initializer; + } + + /** {@inheritDoc} */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + + Operand tLabels = cast(getTF(), labels, type); + if (tLabels.shape().numDimensions() > 1) tLabels = getTF().shape.flatten(tLabels); + Operand tPredictions = cast(getTF(), predictions, type); + if (tPredictions.shape().numDimensions() > 1) + tPredictions = getTF().shape.flatten(tPredictions); + Operand tSampleWeights = sampleWeights != null ? 
cast(getTF(), sampleWeights, type) : null; + if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) + tSampleWeights = getTF().shape.flatten(tSampleWeights); + + Operand currentCM = + MetricsHelper.confusionMatrix( + getTF(), tLabels, tPredictions, getTF().constant(numClasses), tSampleWeights, type); + return Collections.singletonList(getTF().assignAdd(totalConfusionMatrix, currentCM)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand sumOverRow = tf.reduceSum(totalConfusionMatrix, tf.constant(0)); + Operand sumOverCol = tf.reduceSum(totalConfusionMatrix, tf.constant(1)); + Operand truePositives = + tf.linalg.matrixDiagPart( + totalConfusionMatrix, + tf.constant(0), + cast(tf, tf.constant(0), totalConfusionMatrix.type())); + Operand denominator = tf.math.add(sumOverRow, tf.math.sub(sumOverCol, truePositives)); + Operand numValidEntries = + tf.reduceSum( + tf.dtypes.cast( + tf.math.notEqual(denominator, cast(tf, tf.constant(0), denominator.type())), type), + allAxes(tf, denominator)); + Operand iou = tf.math.divNoNan(truePositives, denominator); + + Operand iouSum = tf.reduceSum(iou, allAxes(tf, iou)); + return tf.math.divNoNan(iouSum, numValidEntries); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java new file mode 100644 index 00000000000..eb8ccaf76d2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -0,0 +1,173 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the mean relative error by normalizing with the given values. + * + *

This metric creates two local variables, total and count that are + * used to compute the mean relative error. This is weighted by sampleWeight, and it is + * ultimately returned as mean relative error: an idempotent operation that simply divides total by + * count. + * + *

If sampleWeight is null, weights default to 1. Use sample_weight of + * 0 to mask * values. + * + * @param The data type for the metric result + */ +public class MeanRelativeError extends Mean { + private Operand normalizer; + + /** + * create a metric with name = class name and reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + */ + protected MeanRelativeError(Ops tf, float[] normalizer, long seed, Class type) { + this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * create a metric with reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, String name, float[] normalizer, long seed, Class type) { + this(tf, name, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * Creates a MeanRelativeError metric with name = class name and reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, double[] normalizer, long seed, Class type) { + this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * create a metric with reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param name the name of the metric. 
If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, String name, double[] normalizer, long seed, Class type) { + this(tf, name, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * create a metric with name = class name and reduction = AUTO + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, Operand normalizer, long seed, Class type) { + this(tf, null, normalizer, seed, type); + } + + /** + * create a metric + * + * @param tf the TensorFlow ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result + */ + protected MeanRelativeError( + Ops tf, String name, Operand normalizer, long seed, Class type) { + super(tf, name, seed, type); + this.normalizer = normalizer; + } + + /** {@inheritDoc} */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Operand tLabels = cast(getTF(), labels, getResultType()); + if (tLabels.shape().numDimensions() > 1) tLabels = getTF().shape.flatten(tLabels); + + Operand tPredictions = cast(getTF(), predictions, getResultType()); + if (tPredictions.shape().numDimensions() > 1) + tPredictions = getTF().shape.flatten(tPredictions); + + LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); + tPredictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + Operand tSampleWeights = + sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { + tSampleWeights = getTF().shape.flatten(tSampleWeights); + } + + tuple = LossesHelper.removeSqueezableDimensions(getTF(), normalizer, tPredictions); + normalizer = tuple.getLabels(); + tPredictions = tuple.getTarget(); + + if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with labels shape %s", + tPredictions.shape(), tLabels.shape())); + + Operand relativeErrors = + getTF() + .math + .divNoNan( + getTF().math.abs(getTF().math.sub(tLabels, tPredictions)), this.getNormalizer()); + + return super.updateStateList(relativeErrors, tSampleWeights); + } + + /** + * Gets the normalizer Operand + * + * @return the normalizer + */ + public Operand getNormalizer() { + return normalizer; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java new file mode 100644 index 00000000000..d9c767965a6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -0,0 +1,186 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.framework.metrics.impl.WeightsBroadcastOps; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that computes the element-wise (weighted) mean of the given tensors. 
+ * + * @param The data type for the metric result + */ +public class MeanTensor extends Metric { + public static final String TOTAL = "total"; + public static final String COUNT = "count"; + private final String totalName; + private final String countName; + private final Class type; + private Shape shape; + private Variable total; + private Variable count; + private Assign totalInitializer; + private Assign countInitializer; + private boolean initialized; + + /** + * Creates a MeanTensor metric, using {@link Class#getSimpleName()} as the name + * + * @param tf the TensorFlow ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public MeanTensor(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + /** + * Creates a MeanTensor metric + * + * @param tf the TensorFlow ops + * @param name the name of this metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public MeanTensor(Ops tf, String name, long seed, Class type) { + super(tf, name, seed); + this.type = type; + this.totalName = this.getVariableName(TOTAL); + this.countName = this.getVariableName(COUNT); + } + + /** + * Creates the Operations that initialize the total and count variables. 
+ * + * @param shape the shape of the variables + * @return true if the variables need initialization, otherwise false; + */ + private boolean init(Shape shape) { + if (!initialized) { + this.shape = shape; + Zeros zeros = new Zeros<>(getTF()); + Operand zero = zeros.call(getTF().constant(shape), type); + + if (total == null) { + total = getTF().withName(totalName).variable(zero); + totalInitializer = getTF().assign(total, zero); + } + if (count == null) { + count = getTF().withName(countName).variable(zero); + countInitializer = getTF().assign(count, zero); + } + this.initialized = true; + return true; + } else { + return false; + } + } + + /** {@inheritDoc */ + @Override + public List updateStateList( + Operand values, Operand sampleWeights) { + Ops tf = getTF(); + Operand tValues = cast(tf, values, type); + Operand tSampleWeights = null; + if (sampleWeights != null) tSampleWeights = cast(tf, sampleWeights, type); + + boolean needsInitialization = init(values.shape()); + + if (!this.shape.equals(values.shape())) { + throw new IllegalArgumentException( + String.format( + "MeanTensor input values must always have the same shape. Expected shape (set during the first call): %s. 
Got %s", + this.shape.toString(), values.shape().toString())); + } + + Operand numValues = tf.onesLike(tValues); + if (tSampleWeights != null) { + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); + try { + tSampleWeights = WeightsBroadcastOps.broadcastWeights(tf, tSampleWeights, tValues); + } catch (IllegalArgumentException ex) { + int ndim = values.shape().numDimensions(); + int weightNdim = tSampleWeights.asOutput().shape().numDimensions(); + int[] range = new int[ndim - weightNdim]; + for (int i = weightNdim; i < ndim; i++) { + range[i - weightNdim] = i; + } + tValues = tf.math.mean(tValues, tf.constant(range)); + } + numValues = tf.math.mul(numValues, tSampleWeights); + tValues = tf.math.mul(tValues, tSampleWeights); + } + + List controlOpsPre = new ArrayList<>(); + if (needsInitialization) { + controlOpsPre.add(countInitializer); + controlOpsPre.add(totalInitializer); + } + Ops tf1 = tf.withSubScope("variables").withControlDependencies(controlOpsPre); + + List controlOps = new ArrayList<>(); + controlOps.add(tf1.assignAdd(this.count, numValues)); + controlOps.add(tf1.assignAdd(this.total, tValues)); + return controlOps; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + if (!this.initialized) { + throw new IllegalStateException( + "MeanTensor does not have any result yet. 
Please use `.update_state(value)` before retrieving the result."); + } + return getTF().math.divNoNan(total, count); + } + + /** @return the total */ + public Variable getTotal() { + return total; + } + + /** @return the count */ + public Variable getCount() { + return count; + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + List controlOpsPre = new ArrayList<>(); + controlOpsPre.add(countInitializer); + controlOpsPre.add(totalInitializer); + return getTF().withSubScope("resetStates").withControlDependencies(controlOpsPre).noOp(); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 95b74bf1eea..e4cc9c3aa3d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -16,15 +16,15 @@ import org.tensorflow.Operand; import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; /** Helper class with built-in metrics functions. */ public class Metrics { - public static final float L2_NORM_EPSILON = 1e-12f; - /** * Computes how often targets are in the top K predictions. * @@ -55,4 +55,52 @@ public static Operand topKCategoricalAccuracy( tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); } + + /** + * Computes how often integer targets are in the top K predictions. + * + *

Standalone usage: + * + *

+   *     Operand<TInt32> labels = tf.constant(new int[]{2, 1});
+   *     Operand<TFloat32> predictions = tf.constant(new float[][]
+   *                            {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+   *     Operand<TFloat32> m = Metrics.sparseTopKCategoricalAccuracy(
+   *                                    tf, labels, predictions, 3)
+   *     //m.shape().toString == "[2]"
+   * 
+ * + * @param tf the TensorFlow Ops. + * @param labels the ground truth values. + * @param predictions The prediction values. + * @param k Number of top elements to look at for computing accuracy. + * @param the data type for the predictions and results + * @param the data type ofr the labels. + * @return the Operand for the Sparse top K categorical accuracy value. + */ + @SuppressWarnings("unchecked") + public static Operand sparseTopKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, int k) { + Operand tLabels; + if (labels.type() != predictions.type()) + tLabels = CastHelper.cast(tf, labels, predictions.type()); + else tLabels = (Operand) labels; + + int predictionsRank = predictions.shape().numDimensions(); + int labelsRank = tLabels.shape().numDimensions(); + + Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); + if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { + if (predictionsRank > 2) { + castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); + } + if (labelsRank > 1) { + tLabels = tf.shape.flatten(tLabels); + } + } + return CastHelper.cast( + tf, + tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), + predictions.type()); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java new file mode 100644 index 00000000000..6b70c6680cb --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -0,0 +1,400 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the precision of the predictions with respect to the labels. + * + *

The metric creates two local variables, truePositives and falsePositives that are used to + * compute the precision. This value is ultimately returned as precision, an idempotent operation + * that simply divides truePositives by the sum of truePositives and falsePositives. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values. + * + *

If topK is set, the metric calculates precision as how often on average a class among the top-k + * classes with the highest predicted values of a batch entry is correct and can be found in the + * label for that entry. + * + *

If classId is specified, the metric calculates precision by considering only the entries in the batch + * for which classId is above the thresholds and/or in the top-k highest predictions, and computing + * the fraction of them for which classId is indeed a correct label. + * + * @param The data type for the metric result + */ +public class Precision extends Metric { + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final float DEFAULT_THRESHOLD = 0.5f; + + private final float[] thresholds; + private final Integer topK; + private final Integer classId; + private final String truePositivesName; + private final String falsePositivesName; + private final Class type; + private Variable truePositives; + private Variable falsePositives; + private final List initializers = new ArrayList<>(); + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values and with a threshold of {@link #DEFAULT_THRESHOLD).} + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, long seed, Class type) { + this(tf, null, new float[] {DEFAULT_THRESHOLD}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values with a threshold of {@link + * #DEFAULT_THRESHOLD).} + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Precision(Ops tf, String name, long seed, Class type) { + this(tf, name, new float[] {DEFAULT_THRESHOLD}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values. + * + * @param tf the TensorFlow Ops + * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values. + * + * @param tf the TensorFlow Ops + * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. 
If null, name defaults to {@link + * Class#getSimpleName()}. + * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, String name, float[] thresholds, long seed, Class type) { + this(tf, name, thresholds, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. 
+ * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, float threshold, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, thresholds, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. 
+ * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, + String name, + float threshold, + Integer topK, + Integer classId, + long seed, + Class type) { + this(tf, name, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Precision( + Ops tf, + String name, + float[] thresholds, + Integer topK, + Integer classId, + long seed, + Class type) { + super(tf, name, seed); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); + float defaultThreshold = topK == null ? DEFAULT_THRESHOLD : MetricsHelper.NEG_INF; + this.thresholds = thresholds == null ? new float[] {defaultThreshold} : thresholds; + this.topK = topK; + this.classId = classId; + + init(); + } + + /** Initializes the variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Operand zero = zeros.call(tf.constant(Shape.of(thresholds.length)), type); + + if (this.truePositives == null) { + this.truePositives = + tf.withName(truePositivesName) + .variable(zero); + initializers.add(tf.assign(truePositives, zero)); + + } + if (this.falsePositives == null) { + this.falsePositives = + tf.withName(falsePositivesName) + .variable(zeros.call(tf.constant(Shape.of(thresholds.length)), type)); + initializers.add(tf.assign(falsePositives, zero)); + } + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives); + + Operand tPredictions = cast(getTF(), predictions, type); + Operand tLabels = cast(getTF(), labels, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(getTF(), sampleWeights, type) : null; + + return new ArrayList( + MetricsHelper.updateConfusionMatrixVariables( + getTF(), + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + thresholds, + topK, + classId, + tSampleWeights, + false, + null)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand result = + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + return thresholds.length == 1 + ? tf.slice( + result, + tf.expandDims(tf.constant(0), tf.constant(0)), + tf.expandDims(tf.constant(1), tf.constant(0))) + : result; + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); + } + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return thresholds; + } + + /** + * Gets the topK value, may be null + * + * @return the topK + */ + public Integer getTopK() { + return topK; + } + + /** + * Gets the classId, may be null + * + * @return the classId + */ + public Integer getClassId() { + return classId; + } + + /** + * Gets the truePositives variable + * + * @return the truePositives + */ + public Variable getTruePositives() { + return truePositives; + } + + /** Gets the falsePositives variable return the falsePositives */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** + * Gets the name of the truePositives variable + * + * @return the truePositivesName + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the name of the falsePositives variable + * + * @return the falsePositivesName + */ + public String getFalsePositivesName() { + return falsePositivesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java new file mode 100644 index 00000000000..2ec66df0ca9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best precision where recall is >= specified value. + * @param The data type for the metric result + */ +public class PrecisionAtRecall + extends SensitivitySpecificityBase { + + private final float recall; + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()} and {@link + * #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param recall the recall. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { + this(tf, null, recall, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number of + * thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param recall the recall. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class type) { + this(tf, name, recall, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param recall the recall. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Class type) { + this(tf, null, recall, numThresholds, seed, type); + } + + /** + * Creates a PrecisionRecall metric. 
+ * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param recall the recall. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public PrecisionAtRecall( + Ops tf, String name, float recall, int numThresholds, long seed, Class type) { + super(tf, name, recall, numThresholds, seed, type); + if (recall < 0f || recall > 1f) + throw new IllegalArgumentException("recall must be in the range [0, 1]."); + this.recall = recall; + } + + @Override + public Operand result() { + Ops tf = getTF(); + + Operand recall = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + Operand sub = tf.math.sub(recall, cast(tf, tf.constant(value), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(this.truePositives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(this.falsePositives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** @return the recall */ + public float getRecall() { + return recall; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java new file mode 100644 index 00000000000..0672b78f229 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ 
-0,0 +1,426 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the recall of the predictions with respect to the labels. + *

This metric creates two local + * variables, truePositives and falseNegatives, that are used to compute the recall. This value is + * ultimately returned as recall, an idempotent operation that simply divides truePositives by the sum of truePositives and falseNegatives. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values. + * + *

If is set, the metric calculates recall as how often on average a class among the labels of a + * batch entry is in the top-k predictions. + * + *

If classId is specified, the metric calculates recall by considering only the entries in the batch + * for which classId is in the label, and computing the fraction of them for which classId is above + * the threshold and/or in the top-k predictions. + * + * @param The data type for the metric result + */ +public class Recall< T extends TNumber> extends Metric< T> { + public static final float DEFAULT_THRESHOLD = 0.5f; + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + + private final float[] thresholds; + private final Integer topK; + private final Integer classId; + private final String truePositivesName; + private final String falseNegativesName; + private final Class type; + private Variable truePositives; + private Variable falseNegatives; + private final List initializers = new ArrayList<>(); + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null, and thresholds set to {@link #DEFAULT_THRESHOLD} + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, long seed, Class type) { + this(tf, null, null, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null and thresholds set to {@link + * #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Recall(Ops tf, String name, long seed, Class type) { + this(tf, name, null, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null. + * + * @param tf The TensorFlow Ops + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float threshold, long seed, Class type) { + this(tf, null, threshold, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null. + * + * @param tf The TensorFlow Ops + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. 
+ * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, threshold, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, float[] thresholds, long seed, Class type) { + this(tf, name, thresholds, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} and using a threshold + * value of {@link #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, null, topK, classId, seed, type); + } + + /** + * Creates a Recall metric using a threshold value of {@link #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, Integer topK, Integer classId, long seed, Class type) { + this(tf, name, null, topK, classId, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float threshold, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, thresholds, topK, classId, seed, type); + } + + /** + * Creates a Recall metric. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. 
This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, + String name, + float threshold, + Integer topK, + Integer classId, + long seed, + Class type) { + this(tf, name, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Recall metric. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, + String name, + float[] thresholds, + Integer topK, + Integer classId, + long seed, + Class type) { + super(tf, name, seed); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); + float defaultThreshold = topK == null ? DEFAULT_THRESHOLD : MetricsHelper.NEG_INF; + + this.thresholds = thresholds == null ? 
new float[] {defaultThreshold} : thresholds; + this.topK = topK; + this.classId = classId; + + init(); + } + + /** Initializes the Variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Operand zero = zeros.call(tf.constant(Shape.of(this.thresholds.length)), type); + if (truePositives == null) { + + truePositives = + tf.withName(truePositivesName) + .variable(zero); + initializers.add(tf.assign(truePositives, zero)); + } + + if (this.falseNegatives == null) { + + falseNegatives = + tf.withName(falseNegativesName) + .variable(zero); + initializers.add(tf.assign(falseNegatives, zero)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives); + + Operand tPredictions = cast(getTF(), predictions, type); + Operand tLabels = cast(getTF(), labels, type); + Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null; + + return MetricsHelper.updateConfusionMatrixVariables( + getTF(), + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + this.thresholds, + this.topK, + this.classId, + tSampleWeights, + false, + null); + } + + @Override + public Operand result() { + Ops tf = getTF(); + Operand result = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + return this.thresholds.length == 1 + ? 
tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1])) + : result; + } + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return this.thresholds; + } + + /** + * Gets the topK value + * + * @return the topK value + */ + public Integer getTopK() { + return this.topK; + } + + /** + * Gets the class id + * + * @return the class id + */ + public Integer getClassId() { + return this.classId; + } + + /** + * Gets the truePositives variable + * + * @return the truePositives variable + */ + public Variable getTruePositives() { + return this.truePositives; + } + + /** + * Gets the falseNegatives variable + * + * @return the falseNegatives variable + */ + public Variable getFalseNegatives() { + return this.falseNegatives; + } + + /** + * Gets the truePositives variable name + * + * @return the truePositives variable name + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the falseNegatives variable name + * + * @return the falseNegatives variable name + */ + public String getFalseNegativesName() { + return falseNegativesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java new file mode 100644 index 00000000000..6c774f0c765 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Where; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class RecallAtPrecision + extends SensitivitySpecificityBase { + + private final float precision; + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()} and {@link + * #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param precision the precision. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { + this(tf, null, precision, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number of + * thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric. 
If null, defaults to {@link Class#getSimpleName()} + * @param precision the precision. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class type) { + this(tf, name, precision, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param precision the precision. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, Class type) { + this(tf, null, precision, numThresholds, seed, type); + } + + /** + * Creates a PrecisionRecall metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param precision the precision. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + */ + public RecallAtPrecision( + Ops tf, String name, float precision, int numThresholds, long seed, Class type) { + super(tf, name, precision, numThresholds, seed, type); + if (precision < 0f || precision > 1f) + throw new IllegalArgumentException("recall must be in the range [0, 1]."); + this.precision = precision; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + + Operand precisions = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falsePositives)); + Operand recalls = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + Operand isFeasible = + tf.math.greaterEqual(precisions, cast(tf, tf.constant(this.value), getType())); + Where feasible = tf.where(isFeasible); + Operand feasibleExists = tf.math.greater(tf.size(feasible), tf.constant(0)); + + Operand gather = + tf.expandDims(tf.gather(recalls, feasible, tf.constant(0)), tf.constant(0)); + return tf.select( + feasibleExists, + tf.reduceMax(gather, allAxes(tf, gather)), + cast(tf, tf.constant(0), getType())); + } + + /** + * Gets the precision + * + * @return the precision + */ + public float getPrecision() { + return precision; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java new file mode 100644 index 00000000000..2133642564b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes root mean squared error metric between labels> and predictions + * . + * + * @param The data type for the metric result + */ +public class RootMeanSquaredError< T extends TNumber> extends Mean< T> { + + /** + * Creates a RootMeanSquaredError metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public RootMeanSquaredError(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a RootMeanSquaredError metric + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + } + + /** {@inheritDoc} */ + @Override + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); + tPredictions = ops.getTarget(); + tLabels = ops.getLabels(); + + Operand errorSquared = + cast(getTF(), getTF().math.squaredDifference(tPredictions, tLabels), getResultType()); + + return super.updateStateList(errorSquared, tSampleWeights); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + return getTF().math.sqrt(getTF().math.divNoNan(this.total, this.count)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java new file mode 100644 index 00000000000..7cf694868e6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -0,0 +1,150 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best sensitivity where sensitivity is >= specified value. + * + *

Sensitivity measures the proportion of actual positives that are correctly + * identified as such (tp / (tp + fn)). + * + *

Specificity measures the proportion of actual negatives that are correctly + * identified as such (tn / (tn + fp)). + * + *

This metric creates four local variables, truePositives, trueNegatives + * , falsePositives and falseNegatives that are used to compute the + * sensitivity at the given specificity. The threshold for the given specificity value is computed + * and used to evaluate the corresponding sensitivity. + * + *

If sampleWeights is null>, weights default to 1. Use sample_weight + * of 0 to mask values. + * + * @see Additional information + * about specificity and sensitivity + * @param The data type for the metric result + */ +public class SensitivityAtSpecificity + extends SensitivitySpecificityBase { + + private final float specificity; + + /** + * Creates a SpecificityAtSensitivity metric with a name of {@link Class#getSimpleName()} and + * {@link #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param specificity the specificity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. + */ + public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class type) { + this(tf, null, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a SpecificityAtSensitivity metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number + * of thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param specificity the specificity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. 
+ */ + public SensitivityAtSpecificity( + Ops tf, String name, float specificity, long seed, Class type) { + this(tf, name, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param specificity the specificity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * specificity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. + */ + public SensitivityAtSpecificity( + Ops tf, float specificity, int numThresholds, long seed, Class type) { + this(tf, null, specificity, numThresholds, seed, type); + } + + /** + * Creates a PrecisionRecall metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param specificity the specificity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * specificity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. 
+ */ + public SensitivityAtSpecificity( + Ops tf, String name, float specificity, int numThresholds, long seed, Class type) { + super(tf, name, specificity, numThresholds, seed, type); + if (specificity < 0f || specificity > 1f) + throw new IllegalArgumentException("specificity must be in the range [0, 1]."); + this.specificity = specificity; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand specificities = + tf.math.divNoNan( + this.trueNegatives, tf.math.add(this.trueNegatives, this.falsePositives)); + Operand sub = + tf.math.sub(specificities, cast(tf, tf.constant(this.getValue()), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(this.truePositives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(this.falseNegatives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** + * Gets the specificity + * + * @return the specificity + */ + public float getSpecificity() { + return specificity; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java new file mode 100644 index 00000000000..156a4995b02 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -0,0 +1,135 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.op.math.Equal; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.Collections; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Calculates how often predictions matches integer labels. + * + *

You can provide logits of classes as predictions, since argmax of logits and probabilities are + * same. + * + *

This metric creates two local variables, `total` and `count` that are used to compute the + * frequency with which predictions matches labels. This frequency is ultimately returned as `sparse + * categorical accuracy`: an idempotent operation that simply divides `total` by `count`. + * + *

If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.' + * + *

Usage: + * + *

+ * + *

+ * SparseCategoricalAccuracy m = new tf.keras.metrics.SparseCategoricalAccuracy();
+ * m.update_state(tf.constant(new float[][] {{2}, {1}},
+ *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+ * Operand<TFloat32> result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.5
+ * 
+ * + *
+ * m.reset_states()
+ * m.update_state(
+ *     tf.constant(new float[][] {{2}, {1}},
+ *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}},
+ *     tf.constant(new float[] {0.7f, 0.3f});
+ * result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.3
+ * 
+ * + *

Usage with tf.keras API: + * + *

+ * Model model = new tf.keras.models.Model(inputs, outputs);
+ * model.compile(
+ *     "sgd",
+ *     loss="mse",
+ *     metrics=["sparse_categorical_accuracy"]);
+ * 
+ * + * @param The data type for the metric result + */ +public class SparseCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a SparseCategoricalAccuracy metric, using name of {@link Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type The data type for the metric result + */ + public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a SparseCategoricalAccuracy metric. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null use {@link Class#getSimpleName()} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type of the metric result. 
+ */ + public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + super.setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, + Operand predictions) { + + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Shape predShape = predictions.asOutput().shape(); + Shape labelsShape = labels.asOutput().shape(); + long predictionsRank = predShape.numDimensions(); + long labelsRank = labelsShape.numDimensions(); + + if (predictionsRank != Shape.UNKNOWN_SIZE + && labelsRank != Shape.UNKNOWN_SIZE + && labelsShape.size((int) labelsRank - 1) == 1) { + tLabels = getTF().squeeze(tLabels, Squeeze.axis(Collections.singletonList(labelsRank - 1L))); + } + Operand argMaxPred = + cast( + getTF(), + getTF().math.argMax(tPredictions, getTF().constant(-1L), TInt64.class), + getResultType()); + + Equal equals = getTF().math.equal(tLabels, argMaxPred); + return getTF().dtypes.cast(equals, getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java new file mode 100644 index 00000000000..7db290530cd --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** @param The data type for the metric result */ +public class SparseTopKCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Metrics.sparseTopKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java new file mode 100644 index 00000000000..59f6f44c1f2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -0,0 +1,151 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best specificity where sensitivity is >= specified value. Sensitivity + * measures the proportion of actual positives that are correctly identified as such + * (tp / (tp + fn)). + * + *

Specificity measures the proportion of actual negatives that are correctly + * identified as such (tn / (tn + fp)). + * + *

This metric creates four local variables, truePositives, trueNegatives + * , falsePositives and falseNegatives that are used to compute the + * specificity at the given sensitivity. The threshold for the given sensitivity value is computed + * and used to evaluate the corresponding specificity. + * + *

If sampleWeights is null>, weights default to 1. Use sample_weight + * of 0 to mask values. + * + * @see Additional information + * about specificity and sensitivity + * @param The data type for the metric result + */ +public class SpecificityAtSensitivity + extends SensitivitySpecificityBase { + + private final float sensitivity; + + /** + * Creates a SpecificityAtSensitivity metric with a name of {@link Class#getSimpleName()} and + * {@link #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. + */ + public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class type) { + this(tf, null, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a SpecificityAtSensitivity metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number + * of thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. 
+ */ + public SpecificityAtSensitivity( + Ops tf, String name, float sensitivity, long seed, Class type) { + this(tf, name, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * sensitivity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. + */ + public SpecificityAtSensitivity( + Ops tf, float sensitivity, int numThresholds, long seed, Class type) { + this(tf, null, sensitivity, numThresholds, seed, type); + } + + /** + * Creates a PrecisionRecall metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * sensitivity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. 
+ */ + public SpecificityAtSensitivity( + Ops tf, String name, float sensitivity, int numThresholds, long seed, Class type) { + super(tf, name, sensitivity, numThresholds, seed, type); + if (sensitivity < 0f || sensitivity > 1f) + throw new IllegalArgumentException("sensitivity must be in the range [0, 1]."); + this.sensitivity = sensitivity; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + + Ops tf = getTF(); + + Operand sensitivities = + tf.math.divNoNan( + this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + Operand sub = + tf.math.sub(sensitivities, cast(tf, tf.constant(this.getValue()), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(this.trueNegatives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(this.falsePositives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** + * Gets the sensitivity + * + * @return the sensitivity + */ + public float getSensitivity() { + return sensitivity; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java new file mode 100644 index 00000000000..4312d7a97f0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -0,0 +1,60 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.Reduce; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the (weighted) sum of the given values. + * + *

For example, if values is [1, 3, 5, 7] then the sum is 16. If the + * weights were specified as [1, 1, 0, 0], then the sum would be 4. + * + *

This metric creates one variable, total, that is used to compute the sum of + * values. This is ultimately returned as sum. + * + *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. + * + + */ +public class Sum extends Reduce { + + /** + * Creates a Sum metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Sum(Ops tf, long seed, Class type) { + super(tf, null, MetricReduction.SUM, seed, type); + } + + /** + * Creates a Sum metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric instance. If null, defaults to {@link Class#getSimpleName()} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Sum(Ops tf, String name, long seed, Class type) { + super(tf, name, MetricReduction.SUM, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java new file mode 100644 index 00000000000..d2db4f368ac --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Computes the poisson loss metric between labels and predictions. + * + * @param The data type for the metric result + */ +public class TopKCategoricalAccuracy + extends MeanMetricWrapper implements LossMetric { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. 
+ * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Metrics.topKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java new file mode 100644 index 00000000000..de6428fed88 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of true negatives. + * + *

If sampleWeights is given, calculates the sum of the weights of true negatives. + * This metric creates one local variable, accumulator that is used to keep track of + * the number of true negatives. + * + *

If sampleWeightsnull, weights + * default to 1. Use + * sampleWeights of 0 to mask values. + * + * @param The data type for the metric result + */ +public class TrueNegatives + extends ConfusionMatrixConditionCount { + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a TrueNegatives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TrueNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a TrueNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range [0, 1]. 
A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.TRUE_NEGATIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java new file mode 100644 index 00000000000..c573b6b5719 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -0,0 +1,128 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of true positives. + * + *

If sampleWeights is given, calculates the sum of the weights of true positives. + * This metric creates one local variable, accumulator that is used to keep track of + * the number of true positives. + * + *

If sampleWeightsnull, weights + * default to 1. Use + * sampleWeights of 0 to mask values. + * @param The data type for the metric result + */ +public class TruePositives + extends ConfusionMatrixConditionCount< T> { + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a TruePositives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TruePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a TruePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range [0, 1]. 
A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is true, below is false). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.TRUE_POSITIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java new file mode 100644 index 00000000000..c9e762d05d4 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -0,0 +1,186 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Abstract base class that calculates the value of the given confusion matrix condition based on + * labels and predictions. + * + * @param The data type for the metric result + */ +public abstract class ConfusionMatrixConditionCount extends Metric { + public static final String ACCUMULATOR = "accumulator"; + public static final float DEFAULT_THRESHOLD = 0.5f; + private final ConfusionMatrixEnum confusionMatrixCond; + private final float[] thresholds; + private final String accumulatorName; + private final Class type; + private Variable accumulator; + private Assign initializer; + + /** + * Creates a ConfusionMatrixConditionCount type of Metric, using a threshold of {@link + * #DEFAULT_THRESHOLD} + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, String name, ConfusionMatrixEnum confusionMatrixCond, long seed, Class type) { + this(tf, name, confusionMatrixCond, DEFAULT_THRESHOLD, seed, type); + } + /** + * Creates a ConfusionMatrixConditionCount type of Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param threshold a threshold value in [0, 1]. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * true, below is false). One metric value is generated for each + * threshold value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, + String name, + ConfusionMatrixEnum confusionMatrixCond, + float threshold, + long seed, + Class type) { + this(tf, name, confusionMatrixCond, new float[] {threshold}, seed, type); + } + + /** + * Creates a ConfusionMatrixConditionCount type of Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param thresholds threshold values in [0, 1]. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * true, below is false). One metric value is generated for each + * threshold value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, + String name, + ConfusionMatrixEnum confusionMatrixCond, + float[] thresholds, + long seed, + Class type) { + super(tf, name, seed); + accumulatorName = this.getVariableName(ACCUMULATOR); + this.type = type; + this.confusionMatrixCond = confusionMatrixCond; + this.thresholds = thresholds; + init(); + } + + private void init() { + Shape variableShape = Shape.of(this.thresholds.length); + + Zeros zeros = new Zeros<>(getTF()); + accumulator = + getTF() + .withName(getAccumulatorName()) + .variable(zeros.call(getTF().constant(variableShape), type)); + initializer = getTF().assign(accumulator, zeros.call(getTF().constant(variableShape), type)); + } + + /** + * Gets the initializer for the accumulator variable + * + * @return the initializer for the accumulator variable + */ + public Assign getInitializer() { + return initializer; + } + + /** {@inheritDoc} */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Operand tLabels = cast(getTF(), labels, type); + Operand tPredictions = cast(getTF(), predictions, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(getTF(), sampleWeights, type) : null; + return new ArrayList<>( + MetricsHelper.updateConfusionMatrixVariables( + getTF(), + Collections.singletonMap(confusionMatrixCond, accumulator), + Collections.singletonMap(confusionMatrixCond, initializer), + tLabels, + tPredictions, + thresholds, + null, + null, + tSampleWeights, + false, + null)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + return getTF().identity(accumulator); + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializer; + } + + /** + * get the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return this.thresholds; + } + + /** @return the accumulatorName */ + public String getAccumulatorName() { + return accumulatorName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java new file mode 100644 index 00000000000..b76356661a9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +/** Enumerate the values for a confusion matrix. 
*/ +public enum ConfusionMatrixEnum { + /** These are cases in which the prediction is true, and reality is true. */ + TRUE_POSITIVES("tp"), + /** These are cases in which the prediction is false, and reality is true. */ + FALSE_POSITIVES("fp"), + /** These are cases in which the prediction is true, and reality is false. */ + TRUE_NEGATIVES("tn"), + /** These are cases in which the prediction is false, and reality is false. */ + FALSE_NEGATIVES("fn"); + + private final String abbrev; + + /** Creates a ConfusionMatrixEnum */ + ConfusionMatrixEnum(String abbrev) { + this.abbrev = abbrev; + } + + /** + * Gets the ConfusionMatrixEnum for this enum value, regardless of case. + * + * @param item either the name of the enumeration value or the abbreviation. + * @return ConfusionMatrixEnum for this enum value, or null if not found. + */ + public static ConfusionMatrixEnum get(String item) { + ConfusionMatrixEnum cm = valueOf(item.toUpperCase()); + if (cm == null) { + for (ConfusionMatrixEnum m : values()) { + if (m.getAbbreviation().equals(item.toLowerCase())) { + return m; + } + } + } + return null; + } + + /** Gets the abbreviation for this enum value */ + public String getAbbreviation() { + return abbrev; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 8a352322f52..cbb24933967 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,20 +15,26 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; +import 
org.tensorflow.framework.utils.SparseTensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Stack; +import org.tensorflow.op.core.*; import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.TopK; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; +import java.util.*; +import java.util.concurrent.atomic.AtomicLong; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -212,7 +218,383 @@ public static Operand broadcastWeights( return ctf.math.mul(weights, tf.onesLike(values)); } - // aliases for mean + /** + * Checks that all the Symbolic Shapes are consistent. + * + * @param tf the TensorFlow Ops + * @param symbols the list of Symbolic Shapes + * @param message the error message if the shapes are not consistent. + * @return a list of Operands to check the consistency of the symbolic shapes ready to add to a + * control dependency. + */ + public static List assertShapes( + Ops tf, List> symbols, String message) { + List updateOperations = new ArrayList<>(); + // check that the symbolic shape rank matches the operands rank. 
+ symbols.forEach( + symbol -> { + Operand operand = symbol.getOperand(); + int rank = symbol.rank(); + Rank tfRank = tf.rank(operand); + Op assertion = + tf.withSubScope("assertShapes-1") + .assertThat( + tf.math.equal(tfRank, tf.constant(rank)), + Collections.singletonList(tf.constant(message))); + updateOperations.add(assertion); + }); + + Map dict = new HashMap<>(); + + // check that each operand's dimension size equals the corresponding symbolic shape's dimensions + // size + symbols.forEach( + symbol -> { + AtomicLong ll = new AtomicLong(); + symbol + .getSymbols() + .forEach( + s -> { + Long size = dict.get(s); + if (size == null) { + size = symbol.getOperand().asOutput().shape().size((int) ll.get()); + dict.put(s, size); + } + Op assertion = + tf.withSubScope("assertShapes-2") + .assertThat( + tf.math.equal( + tf.shape.size( + symbol.getOperand(), + tf.constant(ll.getAndIncrement()), + TInt64.class), + tf.constant(size)), + Collections.singletonList(tf.constant(message))); + updateOperations.add(assertion); + }); + }); + + return updateOperations; + } + + /** + * Returns an op to update the given confusion matrix variables. + * + *

For every pair of values in labels and predictions: + * + *

+   * TRUE_POSITIVES:  labels == true and predictions > thresholds
+   * FALSE_POSITIVES: labels == true and predictions <= thresholds
+   * TRUE_NEGATIVES:  labels == false and predictions <= thresholds
+   * FALSE_NEGATIVE:  labels == false and predictions > thresholds
+   * 
+ * + *

The results will be weighted and added together. When multiple thresholds are provided, we + * will repeat the same for every threshold. + * + *

For estimation of these metrics over a stream of data, the function creates an `update_op` + * operation that updates the given variables. + * + *

If sampleWeight is null, weights default to 1. Use weights of 0 to + * mask values. + * + * @param tf the TensorFlow Ops + * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding variables to update as values. + * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding initializer Operands to initializer the corresponding variables from + * variablesToUpdate. + * @param labels the labels, will be cast to {@link TBool} + * @param predictions the predictions whose values are in the range [0, 1]. + * @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when + * topK is set) + * @param topK Optional, indicates that the positive labels should be limited to the top k + * predictions, may be null. + * @param classId Optional, limits the prediction and labels to the specified class + * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as + * labels, and must be broadcast to labels (i.e., all dimensions + * must be either 1, or the same as the corresponding labels + * dimension). + * @param multiLabel indicates whether multidimensional prediction/labels should be treated as + * multilabel responses, or flattened into a single label. When true, the values of + * variablesToUpdate must have a second dimension equal to the number of labels and + * predictions, and those tensors must not be RaggedTensors. + * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied + * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES + * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. 
+ * @param the data type for the variables + * @throws IllegalArgumentException If predictions and labels have + * mismatched shapes, or if sampleWeight is not null>and its shape + * doesn't match predictions + * @return an op to update the given confusion matrix variables. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static List updateConfusionMatrixVariables( + Ops tf, + Map> variablesToUpdate, + Map> varInitializers, + Operand labels, + Operand predictions, + float[] thresholds, + Integer topK, + Integer classId, + Operand sampleWeight, + boolean multiLabel, + Operand labelWeights) { + if (multiLabel && labelWeights != null) + throw new IllegalArgumentException( + "labelWeights for multilabel data should be handled outside of updateConfusionMatrixVariables when multiLabel is true."); + + if (variablesToUpdate == null || variablesToUpdate.isEmpty()) { + return Collections.EMPTY_LIST; + } + + Operand lLabels = labels; + Operand lPredictions = predictions; + Operand lSampleWeight = sampleWeight; + + Operand numThresholds; + Operand oneThresh; + if (multiLabel) { + numThresholds = tf.shape.size(lLabels, tf.constant(0)); + oneThresh = tf.math.equal(tf.constant(1), tf.constant(thresholds.length)); + } else { + // TODO handle Ragged Tensors???? 
+ // [y_pred, + // y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], + // sampleWeights) + numThresholds = tf.constant(thresholds.length); + oneThresh = tf.constant(true); + } + + List controlOps = new ArrayList<>(); + Operand axes = allAxes(tf, lPredictions); + controlOps.add( + tf.withSubScope("updateConfusionMatrixVariables-1") + .assertThat( + tf.reduceAll( + tf.math.greaterEqual( + lPredictions, cast(tf, tf.constant(0), lPredictions.type())), + axes), + Collections.singletonList(tf.constant("predictions must be >= 0")))); + controlOps.add( + tf.withSubScope("updateConfusionMatrixVariables-2") + .assertThat( + tf.reduceAll( + tf.math.lessEqual(lPredictions, cast(tf, tf.constant(1), lPredictions.type())), + axes), + Collections.singletonList(tf.constant("predictions must be <= 1")))); + + LossTuple result = + LossesHelper.squeezeOrExpandDimensions(tf, lLabels, lPredictions, lSampleWeight); + lPredictions = result.getTarget(); + lLabels = result.getLabels(); + lSampleWeight = result.getSampleWeights(); + + if (!lPredictions.shape().isCompatibleWith(lLabels.shape())) + throw new IllegalArgumentException( + String.format( + "Shapes %s and %s are incompatible)", + lPredictions.shape().toString(), lLabels.asOutput().shape().toString())); + + if (topK != null) { + lPredictions = filterTopK(tf, lPredictions, topK); + } + + if (classId != null) { + lLabels = tf.squeeze(tf.gather(lLabels, tf.constant(new int[] {classId}), tf.constant(1))); + lPredictions = + tf.squeeze(tf.gather(lPredictions, tf.constant(new int[] {classId}), tf.constant(1))); + lLabels = tf.expandDims(lLabels, tf.constant(0)); + lPredictions = tf.expandDims(lPredictions, tf.constant(0)); + } + org.tensorflow.op.core.Shape predShape = tf.shape(lPredictions); + Operand numPredictions = + tf.reshape(tf.shape.size(lPredictions, tf.constant(0)), tf.constant(Shape.scalar())); + Operand numLabels = + tf.select( + tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), + 
tf.constant(1), + tf.reduceProd( + tf.shape.takeLast( + predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), + tf.constant(0))); + Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); + + Operand predictionsExtraDim; + Operand labelsExtraDim; + if (multiLabel) { + predictionsExtraDim = tf.expandDims(lPredictions, tf.constant(0)); + labelsExtraDim = tf.expandDims(cast(tf, lLabels, TBool.class), tf.constant(0)); + } else { + predictionsExtraDim = tf.reshape(lPredictions, tf.constant(Shape.of(1, -1))); + labelsExtraDim = tf.reshape(cast(tf, lLabels, TBool.class), tf.constant(Shape.of(1, -1))); + } + List> threshPretileShape; + List> threshTiles; + List> dataTiles; + if (multiLabel) { + threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); + + threshTiles = Arrays.asList(tf.constant(1), numPredictions, threshLabelTile); + dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1)); + } else { + threshPretileShape = Arrays.asList(numThresholds, tf.constant(-1)); + Operand mul = tf.math.mul(numPredictions, numLabels); + threshTiles = Arrays.asList(tf.constant(1), mul); + dataTiles = Arrays.asList(numThresholds, tf.constant(1)); + } + + Operand thresholdsReshaped = + tf.reshape( + cast(tf, tf.constant(thresholds), predictions.type()), tf.stack(threshPretileShape)); + Operand threshTilesShape = tf.stack(threshTiles); + Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); + Operand predsTiled = tf.tile(predictionsExtraDim, tf.stack(dataTiles)); + + // Compare predictions and threshold. 
+ Operand predIsPos = tf.math.greater(predsTiled, threshTiled); + // Tile labels by number of thresholds + Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles)); + Operand weightsTiled; + if (lSampleWeight != null) { + lSampleWeight = + tf.broadcastTo(cast(tf, lSampleWeight, predictions.type()), tf.shape(lPredictions)); + weightsTiled = tf.tile(tf.reshape(lSampleWeight, tf.stack(threshTiles)), tf.stack(dataTiles)); + } else { + weightsTiled = null; + } + + if (labelWeights != null) { + Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0)); + lLabelWeights = tf.broadcastTo(cast(tf, lLabelWeights, labelWeights.type()), lPredictions); + Operand labelWeightsTiled = + tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles)); + if (weightsTiled == null) { + weightsTiled = labelWeightsTiled; + } else { + weightsTiled = tf.math.mul(weightsTiled, labelWeightsTiled); + } + } + + Map loopVars = new HashMap<>(); + loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, new Operand[] {labelIsPos, predIsPos}); + Variable update_tn = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); + Variable update_fp = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); + Variable update_fn = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); + + Operand predIsNeg = null; + Operand labelIsNeg; + if (update_fn != null || update_tn != null) { + predIsNeg = tf.math.logicalNot(predIsPos); + loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, new Operand[] {labelIsPos, predIsNeg}); + } + + if (update_fp != null || update_tn != null) { + labelIsNeg = tf.math.logicalNot(labelIsPos); + loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, new Operand[] {labelIsNeg, predIsPos}); + if (update_tn != null) { + loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, new Operand[] {labelIsNeg, predIsNeg}); + } + } + + final Operand weightsTiledF = weightsTiled; + loopVars + .keySet() + .forEach( + (c) -> { + if 
(variablesToUpdate.containsKey(c)) { + Operand[] op = loopVars.get(c); + // op[0] = label, op[1] == prediction + controlOps.add( + weightedAssignAdd( + tf, + op[0], + op[1], + weightsTiledF, + variablesToUpdate.get(c), + varInitializers.get(c))); + } + }); + + return controlOps; + } + + /** + * Creates an Operand that adds the values by taking the logical and of labels and predictions to + * the specified confusion matrix variable. + * + * @param tf The TensorFlow Ops + * @param labels the labels + * @param predictions the predictions + * @param weights the weights applied to the logical and result, may be null + * @param variable the variable to update + * @param initializer the variable initializer to be applied to the variable, may be null. + * @param the data type for the variable. + * @return an Operand that updates the variable. + */ + private static Operand weightedAssignAdd( + Ops tf, + Operand labels, + Operand predictions, + Operand weights, + Variable variable, + Assign initializer) { + Class type = variable.type(); + Operand labelAndPred = cast(tf, tf.math.logicalAnd(labels, predictions), type); + + if (weights != null) { + Operand lWeights = cast(tf, weights, type); + labelAndPred = tf.math.mul(labelAndPred, lWeights); + } + Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); + Operand assignAdd; + if (initializer != null) { + Ops tfc = + tf.withSubScope("weightedAssignAdd") + .withControlDependencies(Collections.singletonList(initializer)); + assignAdd = tfc.assignAdd(variable, valueSum); + } else { + assignAdd = tf.assignAdd(variable, valueSum); + } + return assignAdd; + } + + /** + * Filters top-k values in the last dim of x and set the rest to NEG_INF. + * + *

Used for computing top-k prediction values in dense labels (which has the same shape as + * predictions) for recall and precision top-k metrics. + * + * @param tf The TensorFlow Ops + * @param x the tensor with any dimensions to filter + * @param topK the number of values to keep. + * @param the data type for x and the return value. + * @return the topK prediction values. + */ + private static Operand filterTopK(Ops tf, Operand x, int topK) { + Class type = x.type(); + Shape xShape = x.shape(); + TopK top = tf.nn.topK(x, tf.constant(topK), TopK.sorted(false)); + OneHot oneHot = + tf.oneHot( + top.indices(), + cast(tf, tf.constant(xShape.size(xShape.numDimensions() - 1)), TInt32.class), + tf.constant(1), + tf.constant(0), + OneHot.axis(-1L)); + Operand topKMask = cast(tf, tf.reduceSum(oneHot, tf.constant(-2)), type); + + // x * top_k_mask + NEG_INF * (1 - top_k_mask) + Operand add1 = tf.math.mul(x, topKMask); + Operand add2 = + tf.math.mul( + cast(tf, tf.constant(NEG_INF), type), + tf.math.sub(cast(tf, tf.constant(1), type), topKMask)); + return tf.math.add(add1, add2); + } + + // alias for mean /** * Calculate the mean of the operand, along all axes and keepDims is false @@ -279,6 +661,103 @@ public static Operand mean( return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } + public static + LossTuple raggedAssertCompatibleAndGetFlatValues( + Ops tf, Operand labels, Operand predictions) { + // TODO handle ragged Tensors + Operand tLabels = cast(tf, labels, predictions.type()); + return new LossTuple<>(tLabels, predictions); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + * @param tf the TensorFlow Ops + * @param labels 1-D `Tensor` of real labels for the classification task. + * @param predictions 1-D `Tensor` of predictions for a given classification. + * @param numClasses The possible number of labels the classification task can have. 
+ * @param weights optional weights to be applied to the confusion matrix + * @param type Data type of the confusion matrix. + * @param the type of Operands + * @return A Operand of type type with shape [n, n] + * representing the confusion matrix, where n is the number of possible labels in + * the classification task. + * @throws IllegalArgumentException If both predictions and labels do + * not have compatible shapes, or if weights is notnull and its + * shape is not compatible with predictions. + */ + public static Operand confusionMatrix( + Ops tf, + Operand labels, + Operand predictions, + Operand numClasses, + Operand weights, + Class type) { + if (!predictions.shape().isCompatibleWith(labels.shape())) + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with labels shape %s", + predictions.shape().toString(), labels.shape().toString())); + tf = tf.withSubScope("confusionMatrix"); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, predictions, labels, null); + Operand lPredictions = cast(tf, ops.getTarget(), TInt64.class); + Operand lLabels = cast(tf, ops.getLabels(), TInt64.class); + + List labelControls = new ArrayList<>(); + List predictionControls = new ArrayList<>(); + + labelControls.add( + tf.assertThat( + tf.reduceAny(tf.math.greaterEqual(lLabels, tf.constant(0L)), allAxes(tf, lLabels)), + Collections.singletonList(tf.constant("`labels` contains negative values")))); + + predictionControls.add( + tf.assertThat( + tf.reduceAny( + tf.math.greaterEqual(lPredictions, tf.constant(0L)), allAxes(tf, lPredictions)), + Collections.singletonList(tf.constant("`predictions` contains negative values")))); + if (numClasses == null) { + numClasses = + tf.math.maximum( + tf.reduceMax(lPredictions, allAxes(tf, lPredictions)), + tf.reduceMax(lLabels, allAxes(tf, lLabels))); + } else { + labelControls.add( + tf.assertThat( + tf.reduceAny(tf.math.less(lLabels, numClasses), allAxes(tf, lLabels)), + 
Collections.singletonList(tf.constant("``labels` out of bounds")))); + predictionControls.add( + tf.assertThat( + tf.reduceAny(tf.math.less(lPredictions, numClasses), allAxes(tf, lPredictions)), + Collections.singletonList(tf.constant("``predictions` out of bounds")))); + } + + if (weights != null) { + if (!lPredictions.shape().isCompatibleWith(weights.shape())) { + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with weights shape %s", + lPredictions.shape().toString(), weights.shape().toString())); + } + } + + Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls); + lLabels = tfc.identity(lLabels); + + tfc = tf.withSubScope("confusionMatrixPredicitons").withControlDependencies(predictionControls); + lPredictions = tfc.identity(lPredictions); + + Operand shape = tf.stack(Arrays.asList(numClasses, numClasses)); + Operand indices = tf.stack(Arrays.asList(lLabels, lPredictions), Stack.axis(1L)); + Operand values = + weights == null ? 
cast(tf, tf.onesLike(lPredictions), type) : cast(tf, weights, type); + SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); + Operand zeroMatrix = tf.zeros(shape, type); + + return tf.sparse.sparseTensorDenseAdd( + cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); + } + /** * Calculate the mean of the operand, along all axes and keepDims is false * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java new file mode 100644 index 00000000000..3949ede822a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -0,0 +1,277 @@ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.*; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Abstract base class for computing sensitivity and specificity. 
+ * + * @param The data type for the metric result + */ +public abstract class SensitivitySpecificityBase extends Metric { + + public static final int DEFAULT_NUM_THRESHOLDS = 200; + + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + protected final int numThresholds; + protected final float value; + protected final float[] thresholds; + private final String truePositivesName; + private final String falsePositivesName; + private final String trueNegativesName; + private final String falseNegativesName; + private final Class type; + protected Variable truePositives; + protected Variable falsePositives; + protected Variable trueNegatives; + protected Variable falseNegatives; + + private Assign truePositivesInitializer; + private Assign falsePositivesInitializer; + private Assign trueNegativesInitializer; + private Assign falseNegativesInitializer; + + /** + * Creates a SensitivitySpecificityBase Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric instance, if null then {@link Class#getSimpleName()} is used + * @param value A scalar value in range `[0, 1]` + * @param numThresholds The number of thresholds to use for matching the given recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0. 
+ */ + protected SensitivitySpecificityBase( + Ops tf, String name, float value, int numThresholds, long seed, Class type) { + super(tf, name, seed); + if (numThresholds <= 0) throw new IllegalArgumentException("numThresholds must be > 0."); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); + this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES); + this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); + + this.value = value; + this.numThresholds = numThresholds; + + if (this.numThresholds == 1) { + this.thresholds = new float[] {0.5f}; + } else { + this.thresholds = new float[numThresholds]; + for (int i = 0; i < numThresholds - 2; i++) { + this.thresholds[i + 1] = (i + 1f) / (float) (numThresholds - 1); + } + this.thresholds[numThresholds - 1] = 1f; + } + init(); + } + + /** Initializes the Variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Shape varShape = Shape.of(numThresholds); + Operand zero = zeros.call(tf.constant(varShape), type); + + if (this.getTruePositives() == null) { + + truePositives = tf.withName(truePositivesName).variable(zero); + truePositivesInitializer = tf.assign(truePositives, zero); + } + if (this.getFalsePositives() == null) { + + falsePositives = tf.withName(falsePositivesName).variable(zero); + falsePositivesInitializer = tf.assign(falsePositives, zero); + } + if (this.getTrueNegatives() == null) { + + trueNegatives = tf.withName(trueNegativesName).variable(zero); + trueNegativesInitializer = tf.assign(trueNegatives, zero); + } + if (this.getFalseNegatives() == null) { + + falseNegatives = tf.withName(falseNegativesName).variable(zero); + falseNegativesInitializer = tf.assign(falseNegatives, zero); + } + } + + public Op initializeVariables() { + List varInitializers = new ArrayList<>(); + + if(truePositivesInitializer != null ) { + varInitializers.add(truePositivesInitializer); 
+ } + if(falsePositivesInitializer != null ) { + varInitializers.add(falsePositivesInitializer); + } + if(trueNegativesInitializer != null ) { + varInitializers.add(trueNegativesInitializer); + } + if(falseNegativesInitializer != null ) { + varInitializers.add(falseNegativesInitializer); + } + + return getTF().withControlDependencies(varInitializers).noOp(); + + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + Ops tf = getTF(); + Operand tLabels = cast(tf, labels, type); + Operand tPredictions = cast(tf, predictions, type); + Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null; + + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.getTruePositives()); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.getFalsePositives()); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.getTrueNegatives()); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.getFalseNegatives()); + + return MetricsHelper.updateConfusionMatrixVariables( + tf, + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + this.getThresholds(), + null, + null, + tSampleWeights, + false, + null); + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializeVariables(); + } + + /** + * Gets the truePositives variable + * + * @return the truePositives + */ + public Variable getTruePositives() { + return truePositives; + } + + /** + * Gets the falsePositives variable + * + * @return the falsePositives truePositives + */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** + * Gets the trueNegatives variable + * + * @return the trueNegatives truePositives + */ + public Variable getTrueNegatives() { + return trueNegatives; + } + + /** + * Gets the falseNegatives variable + * + * @return the falseNegatives 
truePositives + */ + public Variable getFalseNegatives() { + return falseNegatives; + } + + /** + * Gets the numThresholds + * + * @return the numThresholds + */ + public int getNumThresholds() { + return numThresholds; + } + + /** + * Gets the value + * + * @return the value + */ + public float getValue() { + return value; + } + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return thresholds; + } + + /** + * Gets the truePositives variable name + * + * @return the truePositivesName + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the falsePositives variable name + * + * @return the falsePositivesName + */ + public String getFalsePositivesName() { + return falsePositivesName; + } + + /** + * Gets the trueNegatives variable name + * + * @return the trueNegativesName + */ + public String getTrueNegativesName() { + return trueNegativesName; + } + + /** + * Gets the falseNegatives variable name + * + * @return the falseNegativesName + */ + public String getFalseNegativesName() { + return falseNegativesName; + } + + /** + * Gets the type + * + * @return the type + */ + public Class getType() { + return type; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java new file mode 100644 index 00000000000..d28185ae041 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class SymbolicShape { + private Operand operand; + private List symbols = new ArrayList<>(); + + public SymbolicShape(Operand operand, String... symbols) { + this.operand = operand; + this.symbols.addAll(Arrays.asList(symbols)); + } + + /** @return the operand */ + public Operand getOperand() { + return operand; + } + + /** @param operand the operand to set */ + public void setOperand(Operand operand) { + this.operand = operand; + } + + /** @return the symbols */ + public List getSymbols() { + return symbols; + } + + /** @param symbols the symbols to set */ + public void setSymbols(List symbols) { + this.symbols = symbols; + } + + public int rank() { + return this.symbols.size(); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java new file mode 100644 index 00000000000..09752798ad5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -0,0 +1,186 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class WeightsBroadcastOps { + + private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX = + "weights can not be broadcast to values."; + + /** + * Asserts that `weights` can be broadcast to `values` + * + * @param tf the TensorFlow Ops + * @param weights `Tensor` of weights. + * @param values `Tensor` of values to which weights are applied. + * @return `Operation` raising `InvalidArgumentError` if `weights` has incorrect shape. `no_op` if + * static checks determine `weights` has correct shape. + * @param the type of weights and values + * @throws IllegalArgumentException If static checks determine `weights` has incorrect shape. 
+ */ + @SuppressWarnings("unchecked") + public static Op assertBroadcastable( + Ops tf, Operand weights, Operand values) { + Operand weightsShape = tf.shape(weights); + Operand weightsRank = tf.rank(weights); + Shape weightsShapeStatic = weights.shape(); + int weightsRankStatic = weightsShapeStatic.numDimensions(); + + Operand valuesShape = tf.shape(values); + Operand valuesRank = tf.rank(values); + Shape valuesShapeStatic = values.asOutput().shape(); + int valuesRankStatic = valuesShapeStatic.numDimensions(); + + if (weightsRankStatic != -1 && valuesRankStatic != -1) { + if (weightsRankStatic == 0) { + return tf.withSubScope("staticScalarCheckSuccess") + .withControlDependencies(Collections.EMPTY_LIST) + .noOp(); + } + if (weightsRankStatic != valuesRankStatic) { + throw new IllegalArgumentException( + String.format( + "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.", + ASSERT_BROADCASTABLE_ERROR_PREFIX, + valuesRankStatic, + weightsRankStatic, + valuesShapeStatic.toString(), + weightsShapeStatic.toString())); + } + + for (int i = 0; i < valuesRankStatic; i++) { + if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + throw new IllegalArgumentException( + String.format( + "%s Mismatch at dim %s. values.shape=%s weights.shape=%s.", + ASSERT_BROADCASTABLE_ERROR_PREFIX, + i, + valuesShapeStatic.toString(), + weightsShapeStatic.toString())); + } + } + return tf.withSubScope("staticDimsCheckSuccess") + .withControlDependencies(Collections.EMPTY_LIST) + .noOp(); + } + // Dynamic checks. 
+ Operand is_scalar = tf.math.equal(weightsRank, tf.constant(0)); + List> data = + Arrays.asList( + tf.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX), + tf.constant("weights.shape="), + weightsShape, + tf.constant("values.shape="), + valuesShape, + tf.constant("is_scalar="), + is_scalar); + + Operand isValidShape = + tf.select( + is_scalar, + is_scalar, + hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); + + return tf.assertThat(isValidShape, data); + } + + /** + * Check to see that weights and values have the same rank, if they do, then check each + * corresponding dim of each. + * + * @param tf The TensorFlow Ops + * @param weightsRank the rank operand for the weights + * @param weightsShape the shape operand for the weights + * @param valuesRank the rank operand for the values + * @param valuesShape the shape operand for the values + * @return a boolean Operand, true if both shapes have the same rank, and each dimension is the + * same + */ + private static Operand hasValidNonscalarShape( + Ops tf, + Operand weightsRank, + Operand weightsShape, + Operand valuesRank, + Operand valuesShape) { + tf = tf.withSubScope("hasValidNonscalarShape"); + Operand isSameRank = tf.math.equal(valuesRank, weightsRank); + return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank); + } + + /** + * Checks that each dimension of the two shapes are the same + * + * @param tf the TensorFlow Ops + * @param weightsShape the shape of the weights + * @param valuesShape the shape of the values + * @return a boolean Operand, true if all the dimensions of the two shapes are the same. + */ + private static Operand hasValidDims( + Ops tf, Operand weightsShape, Operand valuesShape) { + tf = tf.withSubScope("hasInvalidDims"); + Operand diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0)); + return tf.math.equal(tf.constant(0), diff); + } + + /** + * Broadcast `weights` to the same shape as `values`. + * + *

This returns a version of weights following the same broadcast rules as + * mul(weights, + * values), but limited to the weights shapes allowed by assertBroadcastable + * When computing a weighted average, use this function to broadcast weights before + * summing them; e.g., reduceSum(w * v) / reduceSum(_broadcast_weights(w, v)). + * + * @param tf the TensorFlow ops + * @param weights `Tensor` whose shape is able to be broadcast to `values` + * @param values Tensor` of any shape + * @param the type of Operand + * @return weights broadcast to values shape + */ + public static Operand broadcastWeights( + Ops tf, Operand weights, Operand values) { + tf = tf.withSubScope("broadcast_weights"); + Operand tValues = cast(tf, values, weights.type()); + + Shape weightsShape = weights.shape(); + Shape valuesShape = tValues.shape(); + + if (!weightsShape.hasUnknownDimension() + && !valuesShape.hasUnknownDimension() + && weightsShape.isCompatibleWith(valuesShape)) { + return weights; + } + + Op dependencies = assertBroadcastable(tf, weights, tValues); + Ops tf1 = + tf.withSubScope("assertBroadcastable") + .withControlDependencies(Collections.singletonList(dependencies)); + return tf1.math.mul(weights, tf.onesLike(tValues)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java new file mode 100644 index 00000000000..9dee070eea9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java @@ -0,0 +1,75 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.utils; + +import org.tensorflow.Operand; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * This is a helper class that represents a sparse tensor who's attributes may be passed to + * {@link org.tensorflow.op.Ops#sparse} methods. + * + * @param the type of the SparseTensor + */ +public class SparseTensor { + private final Operand indices; + private final Operand values; + private final Operand denseShape; + + /** + * Creates a SparseTensor + * + * @param indices A 2-D int64 tensor of shape `[N, ndims]`, which specifies the + * indices of the elements in the sparse tensor that contain nonzero values + * @param values A 1-D tensor of any type and shape `[N]`, which supplies the + * values for each element in `indices`. + * @param denseShape A 1-D int64 tensor of shape `[ndims]`, which specifies the + * dense_shape of the sparse tensor + * @throws IllegalArgumentException When building an eager SparseTensor if `dense_shape` is + * unknown or contains unknown elements (None or -1). 
+ */ + public SparseTensor (Operand indices, Operand values, Operand denseShape) { + this.indices = indices; + this.values = values; + this.denseShape = denseShape; + } + + /** + * Gets the indices for the Sparse Tensor + * @return the indices + */ + public Operand getIndices() { + return indices; + } + + /** + * Get the values for the Sparse Tensor + * @return the values + */ + public Operand getValues() { + return values; + } + + /** + * Gets the dense shape for the Sparse Tensor + * + * @return the denseShape + */ + public Operand getDenseShape() { + return denseShape; + } + +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java new file mode 100644 index 00000000000..88825b5f32e --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -0,0 +1,324 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +import static org.junit.jupiter.api.Assertions.*; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class AUCTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float epsilon = 1e-4F; + + int numThresholds = 3; + float[] predArray = new float[] {0f, 0.5f, 0.3f, 0.9f}; + int[] trueArray = new int[] {0, 0, 1, 1}; + float[] sampleWeight = new float[] {1, 2, 3, 4}; + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(predArray); + Operand yTrue = tf.constant(trueArray); + AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + + session.run(tf.init()); + + Op update = instance.updateState(yTrue, yPred, null); + + for (int i = 0; i < 10; i++) { + session.run(update); + } + + Operand result = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(result, instance.result()); + } + } + } + + @Test + public void basicTestSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + assertEquals(numThresholds, instance.getNumThresholds()); + float[] expectedThresholds = new float[] {-1e-7f, 0.5f, 1 + 1e-7f}; + assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); + + instance.resetStates(); + Operand yPred = tf.constant(new float[] {0, 0, 1, 1}); + Operand yTrue = tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}); + Operand 
sampleWeights = tf.constant(new float[] {1, 0, 0, 1}); + + Op update = instance.updateState(yTrue, yPred, sampleWeights); + session.run(update); + Operand result = instance.result(); + session.evaluate(1.0f, result); + } + } + + @Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yTrue = cast(tf, tf.constant(this.trueArray), TFloat32.class); + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + + Op update = instance.updateState(yTrue, yTrue, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(1f, result); + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + Operand result = instance.result(); + + // float expectedResult = (0.75f * 1 + 0.25f * 0); + session.evaluate(0.75f, result); + } + } + + @Test + public void testManualThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + AUC instance = new AUC<>(tf, new float[] {0.5f}, 1001L, TFloat32.class); + float[] expectedThresholds = new float[] {-AUC.EPSILON, 0.5f, 1 + AUC.EPSILON}; + assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + Operand result = instance.result(); + + // float expectedResult = (0.75f * 1 + 0.25f * 0); + session.evaluate(0.75f, result); + } + } + + @Test + public 
void testWeightedRocInterpolation() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = (0.78571427f * 1 + 0.2857145f * 0); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedRocMajoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.ROC, + AUCSummationMethod.MAJORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = (1.0f + .5714285f * 0f); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedRocMinoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.ROC, + AUCSummationMethod.MINORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = ( 0.5714285f + 0f * 0f); + session.evaluate(expectedResult, 
result); + } + } + + @Test + public void testWeightedPrMajoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.PR, + AUCSummationMethod.MAJORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.4285715f + 0.5714285f; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedPrMinoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.PR, + AUCSummationMethod.MINORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.7f * 0.4285715f + 0f * 0.5714285f; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedPrInterpolation() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>(tf, this.numThresholds, AUCCurve.PR, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.916613f; + 
session.evaluate(expectedResult, result); + } + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + new AUC<>(tf, -1, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testExtraDims() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // logits = scipy.special.expit(-np.array([[[-10., 10., -10.], [10., -10., 10.]], + // [[-12., 12., -12.], [12., -12., 12.]]], + // dtype=np.float32)) + float[][][] logitsArray = { + { + {9.99954602e-01f, 4.53978687e-05f, 9.99954602e-01f}, + {4.53978687e-05f, 9.99954602e-01f, 4.53978687e-05f} + }, + { + {9.99993856e-01f, 6.14417460e-06f, 9.99993856e-01f}, + {6.14417460e-06f, 9.99993856e-01f, 6.14417460e-06f} + } + }; + + long[][][] labelArray = { + {{1, 0, 0}, {1, 0, 0}}, + {{0, 1, 1}, {0, 1, 1}} + }; + + Operand logits = tf.constant(logitsArray); + Operand labels = tf.constant(labelArray); + + AUC instance = new AUC<>(tf, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(labels, logits, null); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.5f; + session.evaluate(expectedResult, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java new file mode 100644 index 00000000000..48cac95b8a6 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java @@ -0,0 +1,130 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class AccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2, 3, 4}; + float[] predArray = {1, 2, 3, 4}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4F, total); + session.evaluate(4, count); + session.evaluate(1F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] 
trueArray = {2, 1}; + float[] predArray = {2, 0}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(.5F, total); + session.evaluate(.7, count); + session.evaluate(0.71428573f, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] trueArray = {2, 1}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // 2nd run + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(1.4F, total); + session.evaluate(1.4, count); + session.evaluate(1.0F, result); + + // new instance same graph + instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + total = instance.getTotal(); + count = instance.getCount(); + result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, 
count); + session.evaluate(1.0F, result); + + // reset variables + session.run(instance.resetStates()); + result = instance.result(); + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java new file mode 100644 index 00000000000..e8d8350dcdc --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java @@ -0,0 +1,177 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class BinaryAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 0}; + float[] predArray = {1, 0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(2, count); + session.evaluate(1F, result); + } + } + + @Test + public void testPredictionSqueeze() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 0}; + float[] predArray = {1, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = 
instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(4, count); + session.evaluate(0.5F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 1}; + float[] predArray = {1, 0}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.5F, total); + session.evaluate(.7, count); + session.evaluate(0.71428573f, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] trueArray = {2, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + + // 2nd run + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + 
session.evaluate(0.4F, total); + session.evaluate(1.4, count); + session.evaluate(0.2857143F, result); + + // new instance same graph + instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + total = instance.getTotal(); + count = instance.getCount(); + result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + + // reset variables + session.run(instance.resetStates()); + session.evaluate(0.0, total); + session.evaluate(0.0, count); + + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + } + } + + @Test + public void testBinaryAccuracyAThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = + new BinaryAccuracy<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 1, 0, 0}; + float[] predArray = {0.9f, 0.6f, 0.4f, 0.8f}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 1))); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(4, count); + session.evaluate(0.5F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java new file mode 100644 index 00000000000..83990cbaebb --- /dev/null +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java @@ -0,0 +1,156 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class CategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = + new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + 
Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(2, count); + session.evaluate(1F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = + new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = + new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + 
session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // 2nd run + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(1.4F, total); + session.evaluate(1.4, count); + session.evaluate(1.0F, result); + + // new instance same graph + instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + total = instance.getTotal(); + count = instance.getCount(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // reset variables + session.run(instance.resetStates()); + session.evaluate(0, total); + session.evaluate(0, count); + session.evaluate(0, result); + + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java new file mode 100644 index 00000000000..4bd8d99586e --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class FalseNegativesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + FalseNegatives instance = new FalseNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(3.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + FalseNegatives instance = new FalseNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(5.0, result); 
+ } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + FalseNegatives instance = + new FalseNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {1.f, 4.f, 6.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(new double[][] {{3.0}, {5.0}, {7.0}, {4.0}}); + FalseNegatives instance = + new FalseNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {4., 16., 23.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java new file mode 100644 index 00000000000..2584c7a3244 
--- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java @@ -0,0 +1,148 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class FalsePositivesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + FalsePositives instance = new FalsePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + 
session.evaluate(7.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + FalsePositives instance = new FalsePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(14.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + FalsePositives instance = + new FalsePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {7.f, 4.f, 2.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = + tf.constant( + new double[][] { + {1.0, 2.0, 3.0, 5.0}, + {7.0, 
11.0, 13.0, 17.0}, + {19.0, 23.0, 29.0, 31.0}, + {19.0, 23.0, 29.0, 31.0} + }); + FalsePositives instance = + new FalsePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {125., 42., 12.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java new file mode 100644 index 00000000000..fc08455d1c7 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +public class MeanIoUTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final long numClasses = 2L; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testUnweighted"); + Operand predictions = tf.constant(new long[] {0, 1, 0, 1}); + Operand labels = tf.constant(new long[] {0, 0, 1, 1}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + double expected_result = (1. / (2. + 2. - 1.) + 1. / (2. + 2. 
- 1.)) / 2.; + session.evaluate(expected_result, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testWeighted"); + Operand predictions = tf.constant(new long[] {0, 1, 0, 1}); + Operand labels = tf.constant(new long[] {0, 0, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {0.2f, 0.3f, 0.4f, 0.1f}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + float expected_result = (0.2f / (0.6f + 0.5f - 0.2f) + 0.1f / (0.4f + 0.5f - 0.1f)) / 2f; + session.evaluate(expected_result, result); + } + } + + @Test + public void testMultiDimInput() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testMultiDimInput"); + + Operand predictions = tf.constant(new long[][] {{0, 1}, {0, 1}}); + Operand labels = tf.constant(new long[][] {{0, 0}, {1, 1}}); + Operand sampleWeight = tf.constant(new float[][] {{0.2f, 0.3f}, {0.4f, 0.1f}}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + float expected_result = (0.2f / (0.6f + 0.5f - 0.2f) + 0.1f / (0.4f + 0.5f - 0.1f)) / 2f; + session.evaluate(expected_result, result); + } + } + + @Test + public void testZeroValidEntries() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testZeroValidEntries"); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Operand result = instance.result(); + session.evaluate(0.0f, result); + } + } + + 
@Test + public void testZeroAndNonZeroEntries() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testZeroAndNonZeroEntries"); + Operand predictions = tf.constant(new float[] {1}); + Operand labels = tf.constant(new int[] {1}); + + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float expected_result = (0f + 1f / (1f + 1f - 1f)) / 1f; + session.evaluate(expected_result, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java new file mode 100644 index 00000000000..0bb9392b8b0 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class MeanRelativeErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] predArray = new float[][] {{2, 4, 6, 8}}; + float[][] trueArray = new float[][] {{1, 3, 2, 3}}; + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + + MeanRelativeError instance = + new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + double expected_result = 1.25; + session.evaluate(expected_result, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[] predArray = new float[] {2, 4, 6, 8}; + float[] trueArray = new float[] {1, 3, 2, 3}; + float[] sampleWeightArray = new float[] {0.2f, 0.3f, 0.5f, 0f}; + + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + Operand sampleWeight = tf.constant(sampleWeightArray); + + MeanRelativeError instance = + new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = 
instance.result(); + + double expectedResult = 1.3; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testZeroNormalizer() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[] predArray = new float[] {2, 4}; + int[] trueArray = new int[] {1, 3}; + + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + + MeanRelativeError instance = + new MeanRelativeError<>( + tf, cast(tf, tf.zerosLike(labels), TFloat32.class), 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + double expectedResult = 0; + session.evaluate(expectedResult, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java new file mode 100644 index 00000000000..ce473bbdf34 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class MeanTensorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand values = tf.constant(new long[] {100, 40}); + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); + session.run(tf.init()); + Op update = instance.updateState(values, null); + session.run(update); + Operand result = instance.result(); + double[] expected_result = new double[] {100, 40}; + session.evaluate(expected_result, result); + + session.evaluate(expected_result, instance.getTotal()); + session.evaluate(new double[] {1, 1}, instance.getCount()); + + session.run(instance.resetStates()); + session.evaluate(new double[] {0, 0}, instance.getTotal()); + session.evaluate(new double[] {0, 0}, instance.getCount()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand values = tf.constant(new long[] {100, 30}); + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); + session.run(tf.init()); + + // check scalar weight + Op update = instance.updateState(values, tf.constant(0.5f)); + session.run(update); + Operand result = instance.result(); + double[] expected_result = new double[] {100, 30}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {50, 15}, instance.getTotal()); + session.evaluate(new double[] {0.5, 0.5}, instance.getCount()); + + // check weights 
not scalar and weights rank matches values rank + values = tf.constant(new long[] {1, 5}); + update = instance.updateState(values, tf.constant(new double[] {1f, 0.2f})); + session.run(update); + result = instance.result(); + expected_result = new double[] {51 / 1.5, 16 / 0.7}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {51, 16}, instance.getTotal()); + session.evaluate(new double[] {1.5, .7}, instance.getCount()); + + // check weights broadcast + values = tf.constant(new long[] {1, 2}); + update = instance.updateState(values, tf.constant(0.5f)); + session.run(update); + result = instance.result(); + expected_result = new double[] {51.5 / 2, 17 / 1.2}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {51.5, 17}, instance.getTotal()); + session.evaluate(new double[] {2, 1.2}, instance.getCount()); + + // check weights squeeze + values = tf.constant(new long[] {1, 5}); + Operand sampleWeight = tf.constant(new double[][] {{1}, {0.2}}); + update = instance.updateState(values, sampleWeight); + session.run(update); + result = instance.result(); + expected_result = new double[] {52.5 / 3, 18 / 1.4}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {52.5, 18}, instance.getTotal()); + session.evaluate(new double[] {3, 1.4}, instance.getCount()); + } + } + + @Test + public void testWeightedExpand() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + // check weights expand + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat32.class); + + Operand values = tf.constant(new long[][] {{1}, {5}}); + Operand sampleWeight = tf.constant(new float[] {1f, 0.2f}); + Op update = instance.updateState(values, sampleWeight); + session.run(update); + Operand result = instance.result(); + session.evaluate(tf.constant(new float[][] {{1f}, {5f}}), result); + session.evaluate(tf.constant(new float[][] {{1f}, {1f}}), instance.getTotal()); + 
session.evaluate(tf.constant(new float[][] {{1f}, {0.2f}}), instance.getCount()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java new file mode 100644 index 00000000000..a817a3dc5df --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class PrecisionAtRecallTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) session.evaluate(initialPrecision, instance.result()); + } + } + + private int[][] generateRandomArray(int dim1, int dim2) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(2); + } + } + + return result; + } + + @Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new 
PrecisionAtRecall<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighRecall() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.8f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.5f, 0.4f, 0.5f, 0.6f, 0.8f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.8f, precision); + } + } + + @Test + public void testUnweightedLowRecall() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.15f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.5f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = 
session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {2, 2, 1, 1, 1, 1, 1, 2, 2, 2}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(2.f / 3.f, precision); + } + } + + @Test + public void testInvalidSensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new PrecisionAtRecall<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new PrecisionAtRecall<>(tf, 0.7f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java new file mode 100644 index 00000000000..35962a568ca --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java @@ -0,0 +1,339 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +public class PrecisionTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Precision instance = + new Precision<>(tf, new float[] {0.3f, 0.72f}, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L)); + + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) { + session.run(update); + } + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(initialPrecision, instance.result()); + } + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + 
Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new long[][] {{1, 0, 1, 0}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(0.5, precision); + } + } + + @Test + public void testUnweightedAllIncorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = new Precision<>(tf, 0.5f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniformInt(tf.constant(Shape.of(100, 1)), tf.constant(0), tf.constant(2)); + Operand labels = tf.math.sub(tf.constant(1), predictions); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(0.0f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new long[][] {{1, 0, 1, 0}, {1, 0, 1, 0}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}, {1, 0, 0, 1}}); + Operand sampleWeight = tf.constant(new double[][] {{1, 2, 3, 4}, {4, 3, 2, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand precision = instance.result(); + + double weightedTP = 3.0f + 4.0f; + double weightedPositives = (1.0f + 3.0f) + (4.0f + 2.0f); + double expectedPrecision = weightedTP / weightedPositives; + + session.evaluate(expectedPrecision, precision); + } + } + + @Test + public void testDivByZero() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); 
+ Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new int[] {0, 0, 0, 0}); + Operand labels = tf.constant(new int[] {0, 0, 0, 0}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(0, precision); + } + } + + @Test + public void testUnweightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f, 0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + float[] expected = new float[] {0.5f, 0.f}; + + session.evaluate(expected, precision); + } + } + + @Test + public void testWeightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f}, {0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1}, {1, 0}}); + Operand sampleWeight = tf.constant(new float[][] {{4, 0}, {3, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand precision = instance.result(); + + float weightedTP = 0f + 3.f; + float weightedPositives = (0f + 3.f) + (4.f + 0.f); + float expectedPrecision = weightedTP / weightedPositives; + + Float[] expected = new Float[] {expectedPrecision, 0f}; + session.evaluate(expected, precision); + } + } + + @Test + public void testMultipleUpdates() { + try (TestSession 
session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f}, {0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1}, {1, 0}}); + Operand sampleWeight = tf.constant(new double[][] {{4, 0}, {3, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + for (int i = 0; i < 2; i++) session.run(update); + Operand precision = instance.result(); + + double weighted_tp = (0 + 3.) + (0 + 3.); + double weighted_positives = ((0 + 3.) + (4. + 0.)) + ((0 + 3.) + (4. + 0.)); + double expected_precision = weighted_tp / weighted_positives; + + double[] expected = new double[] {expected_precision, 0f}; + session.evaluate(expected, precision); + } + } + + @Test + public void testUnweightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 3 + Precision instance = + new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(1.0f / 3.0f, precision); + } + } + + @Test + public void testWeightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 3 + Precision instance = + new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[] {0.2f, 0.1f, 0.4f, 0f, 0.2f}); + Operand labels = tf.constant(new long[] {0, 1, 1, 0, 1}); + Operand sampleWeight = tf.constant(new float[][] {{1, 4, 2, 3, 5}}); + 
Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.4f, 0.2f, 0.2f}}); + labels = tf.constant(new long[][] {{1, 0, 1, 1, 1}}); + update = instance.updateState(labels, predictions, tf.constant(3.f)); + session.run(update); + + Operand precision = instance.result(); + + float tp = (2f + 5f) + (3f + 3f); + float predicted_positives = (1f + 2f + 5f) + (3f + 3f + 3f); + float expected_precision = tp / predicted_positives; + session.evaluate(expected_precision, precision); + } + } + + @Test + public void testUnweightedClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set classId to 2 + Precision instance = + new Precision<>(tf, null, null, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0f, 0f, 0.2f}}); + labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + labels = tf.constant(new long[][] {{0, 1, 0, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(0.5f, precision); + session.evaluate(1, 
instance.getTruePositives()); + session.evaluate(1, instance.getFalsePositives()); + } + } + + @Test + public void testUnweightedTopKAndClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK and classId to 2 + Precision instance = + new Precision<>(tf, null, 2, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{1f, 1f, 0.9f, 1f, 1f}}); + labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + } + } + + @Test + public void testUnweightedTopKAndThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 2 + Precision instance = + new Precision<>(tf, 0.7f, 2, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 1}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + } + } +} diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java new file mode 100644 index 00000000000..bd3a5273668 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java @@ -0,0 +1,207 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class RecallAtPrecisionTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.7f, 1001L, TFloat32.class); + 
session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + labels = tf.math.mul(labels, tf.constant(2.0f)); + + Op update = instance.updateState(labels, predictions); + + for (int i = 0; i < 10; i++) { + session.run(update); + } + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(initialPrecision, instance.result()); + } + } + } + + private int[][] generateRandomArray(int dim1, int dim2, int maxVal) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(maxVal); + } + } + + return result; + } + + @Test + public void test_unweighted_all_correct() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1, 2); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighPrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.75f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] { + 0.05f, 0.1f, 0.2f, 0.3f, 
0.3f, 0.35f, 0.4f, 0.45f, 0.5f, 0.6f, 0.9f, 0.95f + }); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.5f, precision); + } + } + + @Test + public void testUnweightedLowPrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 2.0f / 3f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] { + 0.05f, 0.1f, 0.2f, 0.3f, 0.3f, 0.35f, 0.4f, 0.45f, 0.5f, 0.6f, 0.9f, 0.95f + }); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(5.f / 6f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.75f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.1f, 0.2f, 0.3f, 0.5f, 0.6f, 0.9f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {1, 2, 1, 2, 1, 2, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.6f, precision); + } + } + + @Test + public void testUnachievablePrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 2.0f / 3f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand 
predictions = tf.constant(new float[] {0.1f, 0.2f, 0.3f, 0.9f}); + Operand labels = tf.constant(new long[] {1, 1, 0, 0}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + // The highest possible precision is 1/2 which is below the required + session.evaluate(0f, precision); + } + } + + @Test + public void test_invalid_sensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new RecallAtPrecision<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void test_invalid_num_thresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new RecallAtPrecision<>(tf, 0.7f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java new file mode 100644 index 00000000000..b9d067a6ed2 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java @@ -0,0 +1,341 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; + +public class RecallTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.3f, 0.72f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class); + Operand labels = + tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class); + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialRecall = instance.result(); + for (int i = 0; i < 10; i++) session.evaluate(initialRecall, instance.result()); + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1, 0, 1, 0}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.5f, instance.result()); + } + } + + private int[][] generateRandomArray(int dim1, int dim2, int maxInt) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = 
random.nextInt(maxInt); + } + } + + return result; + } + + @Test + public void testUnweightedAllIncorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] array = generateRandomArray(100, 1, 2); + Operand predictions = tf.dtypes.cast(tf.constant(array), TFloat32.class); + Operand labels = + tf.dtypes.cast(tf.math.sub(tf.constant(1), tf.constant(array)), TFloat32.class); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.f, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[][] { + {1, 0, 1, 0}, + {0, 1, 0, 1} + }); + Operand labels = + tf.constant( + new float[][] { + {0, 1, 1, 0}, + {1, 0, 0, 1} + }); + + Operand sampleWeights = + tf.constant( + new float[][] { + {1, 2, 3, 4}, + {4, 3, 2, 1} + }); + Op update = instance.updateState(labels, predictions, sampleWeights); + session.run(update); + + float weightedTp = 3.0f + 1.0f; + float weightedT = (2.0f + 3.0f) + (4.0f + 1.0f); + float expectedRecall = weightedTp / weightedT; + + session.evaluate(expectedRecall, instance.result()); + } + } + + @Test + public void testDivByZero() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[] {0, 0, 0, 0}); + Operand labels = tf.constant(new float[] {0, 0, 0, 0}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0f, instance.result()); + } + } + + 
@Test + public void testUnweightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1, 0, 0.6f, 0}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Float[] expected = new Float[] {0.5f, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testWeightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); + Operand predictions = tf.constant(new float[][] {{1, 0}, {0.6f, 0}}); + Operand weights = tf.constant(new float[][] {{1, 4}, {3, 2}}); + + Op update = instance.updateState(labels, predictions, weights); + session.run(update); + + float weightedTp = 0 + 3.f; + float weightedPositives = (0 + 3.f) + (4.f + 0.f); + float expectedRecall = weightedTp / weightedPositives; + float[] expected = new float[] {expectedRecall, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testMultipleUpdates() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); + Operand predictions = tf.constant(new float[][] {{1, 0}, {0.6f, 0}}); + Operand weights = tf.constant(new float[][] {{1, 4}, {3, 2}}); + + Op update = instance.updateState(labels, predictions, weights); + for (int i = 0; i < 2; 
i++) session.run(update); + + float weightedTp = (0f + 3.f) + (0f + 3.f); + float weightedPositives = ((0f + 3.f) + (4.f + 0.f)) + ((0f + 3.f) + (4.f + 0.f)); + float expectedRecall = weightedTp / weightedPositives; + float[] expected = new float[] {expectedRecall, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testUnweightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0f, 1f, 1f, 0f, 0f}}); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.5f, instance.result()); + } + } + + @Test + public void testWeightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 1}}); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.4f, 0f, 0.2f}}); + Operand weights = tf.constant(new float[][] {{1, 4, 2, 3, 5}}); + + Op update = instance.updateState(labels, predictions, weights); + session.run(update); + + labels = tf.constant(new float[][] {{1, 0, 1, 1, 1}}); + predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.4f, 0.2f, 0.2f}}); + weights = tf.constant(3.f); + + update = instance.updateState(labels, predictions, weights); + session.run(update); + + float weightedTp = (2 + 5) + (3 + 3); + float weightedPositives = (4 + 2 + 5) + (3 + 3 + 3 + 3); + float expectedRecall = weightedTp / weightedPositives; + session.evaluate(expectedRecall, instance.result()); + } + } + + @Test + public void testUnweightedClassId() { + try 
(TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, null, null, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(0f, instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0f, 0f, 0.2f}}); + labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + labels = tf.constant(new float[][] {{0, 1, 0, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + } + } + + @Test + public void testUnweightedTopKAndClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, null, 2, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0, 0.2f}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(1f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(0f, 
instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{1, 1, 0.9f, 1, 1}}); + labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + } + } + + @Test + public void testUnweightedTopKAndThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, null, 0.7f, 2, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new float[][] {{1, 1, 1, 0, 1}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.25f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(3f, instance.getFalseNegatives()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java new file mode 100644 index 00000000000..c9ced9f5946 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +public class RootMeanSquaredErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RootMeanSquaredError instance = + new RootMeanSquaredError<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[] {2, 4, 6}); + Operand predictions = tf.constant(new float[] {1, 3, 2}); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(18, total); + session.evaluate(3, count); + session.evaluate(Math.sqrt(6), result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RootMeanSquaredError instance = + new RootMeanSquaredError<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand labels = tf.constant(new float[][] {{2, 4, 6, 8}}); + Operand predictions = tf.constant(new float[][] {{1, 3, 2, 3}}); + Operand sampleWeight = tf.constant(new double[][] {{0, 1, 0, 1}}); + + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); 
+ session.evaluate(26, total); + session.evaluate(2, count); + session.evaluate(Math.sqrt(13), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java new file mode 100644 index 00000000000..a65dc3b53da --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java @@ -0,0 +1,185 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class SensitivityAtSpecificityTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + labels = tf.math.mul(labels, tf.constant(2.0f)); + + // instance.setDebug(session.getGraphSession()); + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialSensitivity = instance.result(); + + for (int i = 0; i < 10; i++) session.evaluate(initialSensitivity, instance.result()); + + // instance.setDebug(null); + + } + } + + private int[][] generateRandomArray(int dim1, int dim2) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(2); + } + } + + return result; + } + + 
@Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighSpecificity() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.8f, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.45f, 0.5f, 0.8f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.8, precision); + } + } + + @Test + public void testUnweightedLowSpecificity() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, 
null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.6f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.4f, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + Operand sampleWeight = tf.constant(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.675, precision); + } + } + + @Test + public void testInvalidSensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new SensitivityAtSpecificity<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new SensitivityAtSpecificity<>(tf, 0.7f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java new file mode 100644 index 00000000000..ff5834eda8e --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java @@ -0,0 +1,184 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class SpecificityAtSensitivityTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + + // instance.setDebug(session.getGraphSession()); + Op update = 
instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialSpecificity = instance.result(); + + for (int i = 0; i < 10; i++) session.evaluate(initialSpecificity, instance.result()); + } + } + + private int[][] generateRandomArray(int dim1, int dim2) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(2); + } + } + + return result; + } + + @Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = tf.constant(trueArray); + labels = tf.math.mul(labels, tf.constant(2)); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighSensitivity() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.8f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.45f, 0.5f, 0.8f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.4f, precision); + } + } + + @Test + public void 
testUnweightedLowSensitivity() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.4f, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.6f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.4f, precision); + } + } + + @Test + public void testInvalidSensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new SpecificityAtSensitivity<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new SpecificityAtSensitivity<>(tf, 0.4f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java new file mode 100644 index 00000000000..941f882b8c8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SumTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + /** Test of call method, of class Sum. 
*/ + @Test + public void testUnWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Sum instance = new Sum<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + assertEquals(TFloat32.class, instance.getResultType()); + session.evaluate(0f, instance.getTotal()); + + Op update = instance.updateState(tf.constant(100f), null); + session.run(update); + session.evaluate(100f, instance.result()); + session.evaluate(100f, instance.getTotal()); + + update = instance.updateState(tf.constant(new float[] {1, 5}), null); + session.run(update); + session.evaluate(106f, instance.result()); + session.evaluate(106f, instance.getTotal()); + + session.run(instance.resetStates()); + session.evaluate(0f, instance.getTotal()); + } + } + + @Test + public void testSumWithSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Sum instance = new Sum<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + // check scalar weight + Op op = instance.updateState(tf.constant(100f), tf.constant(0.5)); + session.run(op); + Operand result = instance.result(); + session.evaluate(50.0, instance.getTotal()); + session.evaluate(50.0, result); + + // check weights not scalar and weights rank matches values rank + op = + instance.updateState(tf.constant(new float[] {1, 5}), tf.constant(new double[] {1, 0.2})); + session.run(op); + result = instance.result(); + session.evaluate(52., instance.getTotal()); + session.evaluate(52., result); + + // check weights broadcast + op = instance.updateState(tf.constant(new float[] {1, 2}), tf.constant(0.5)); + session.run(op); + result = instance.result(); + session.evaluate(53.5, instance.getTotal()); + session.evaluate(53.5, result); + + // check weights squeeze + op = + instance.updateState( + tf.constant(new float[] {1, 5}), tf.constant(new double[][] {{1}, {0.2}})); + session.run(op); + result = 
instance.result(); + session.evaluate(55.5, instance.getTotal()); + session.evaluate(55.5, result); + + // check weights expand + op = + instance.updateState( + tf.constant(new float[][] {{1}, {5}}), tf.constant(new double[] {1, 0.2})); + session.run(op); + result = instance.result(); + session.evaluate(57.5, instance.getTotal()); + session.evaluate(57.5, result); + + // check values reduced to the dimensions of weight + op = + instance.updateState( + tf.constant(new float[][][] {{{1.f, 2.f}, {3.f, 2.f}, {0.5f, 4.f}}}), + tf.constant(new double[] {0.5})); + session.run(op); + result = instance.result(); + session.evaluate(63.75, instance.getTotal()); + session.evaluate(63.75, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java new file mode 100644 index 00000000000..023796ba367 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class TopKCategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrectness() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand labels = tf.constant(new float[][] {{0, 0, 1}, {0, 1, 0}}); + Operand predictions = + tf.constant(new double[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1., instance.result()); + + // With `k` < 5. + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted1", 1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + + // With `k` > 5. 
+ labels = + tf.constant( + new float[][] { + {0, 0, 1, 0, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0} + }); + predictions = + tf.constant( + new double[][] { + {0.5f, 0.9f, 0.1f, 0.7f, 0.6f, 0.5f, 0.4f}, + {0.05f, 0.95f, 0f, 0f, 0f, 0f, 0f} + }); + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted6", 6, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testWeighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = + tf.constant( + new double[][] { + {1, 0, 2}, + {1, 0, 0}, + {0, 0, 1} + }); + Operand predictions = + tf.constant( + new double[][] { + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f} + }); + + Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + session.evaluate(1., instance.result()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java new file mode 100644 index 00000000000..1a68c2ed8b8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class TrueNegativesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + TrueNegatives instance = new TrueNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(3.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = 
tf.constant(this.sampleWeightArray); + TrueNegatives instance = new TrueNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(4.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + TrueNegatives instance = + new TrueNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {2.f, 5.f, 7.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(new double[][] {{0.0, 2.0, 3.0, 5.0}}); + TrueNegatives instance = + new TrueNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {5., 
15., 23.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java new file mode 100644 index 00000000000..c22c1245d97 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class TruePositivesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + TruePositives instance = new TruePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(7.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + TruePositives instance = new TruePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(12.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = 
TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + TruePositives instance = + new TruePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {6.f, 3.f, 1.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(37.); + TruePositives instance = + new TruePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {222., 111., 37.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 63d666f8640..4330fa0aed7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -63,6 +63,7 @@ private void testValid( TestSession testSession, Ops tf, Operand weights, Operand values, Class type) { Op staticOp = MetricsHelper.assertBroadcastable(tf, weights, values); + testSession.run(staticOp); // dynamic test Operand weightsPlaceholder = tf.placeholder(type); From f258e38c2c90735d1b3b3642cba1bf989fd58229 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 7 Feb 2021 13:46:01 -0500 Subject: [PATCH 50/97] Reformat code, fix javadoc --- .../org/tensorflow/framework/metrics/AUC.java | 61 ++++++++++--------- .../framework/metrics/AUCSummationMethod.java | 8 +-- .../framework/metrics/Accuracy.java | 2 +- .../framework/metrics/BinaryAccuracy.java | 2 +- .../framework/metrics/FalseNegatives.java | 9 ++- .../framework/metrics/FalsePositives.java | 8 +-- .../tensorflow/framework/metrics/MeanIoU.java | 6 +- .../framework/metrics/MeanRelativeError.java | 15 +++-- .../framework/metrics/MeanTensor.java | 2 +- .../framework/metrics/Precision.java | 55 +++++++++-------- .../framework/metrics/PrecisionAtRecall.java | 32 +++++++--- .../tensorflow/framework/metrics/Recall.java | 46 +++++++------- .../framework/metrics/RecallAtPrecision.java | 39 ++++++++---- .../metrics/RootMeanSquaredError.java | 13 ++-- .../metrics/SensitivityAtSpecificity.java | 23 +++---- .../metrics/SparseCategoricalAccuracy.java | 20 +++--- .../metrics/SpecificityAtSensitivity.java | 25 ++++---- .../org/tensorflow/framework/metrics/Sum.java | 2 - .../metrics/TopKCategoricalAccuracy.java | 12 ++-- .../framework/metrics/TrueNegatives.java | 6 +- .../framework/metrics/TruePositives.java | 7 +-- .../metrics/impl/ConfusionMatrixEnum.java | 13 +++- .../framework/metrics/impl/MetricsHelper.java | 10 +-- .../impl/SensitivitySpecificityBase.java | 19 +++--- .../metrics/impl/WeightsBroadcastOps.java | 2 +- 25 files changed, 239 insertions(+), 198 deletions(-) diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 62311c3cda5..da89167e1f3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -34,27 +34,29 @@ /** * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. * - *

This metric creates four local variables, truePositives`, trueNegatives`, - * falsePositives` and falseNegatives` that are used to compute the AUC. To discretize the AUC - * curve, a linearly spaced set of thresholds is used to compute pairs of recall and precision - * values. The area under the ROC-curve is therefore computed using the height of the recall values - * by the false positive rate, while the area under the PR-curve is the computed using the height of - * the precision values by the recall. + *

This metric creates four local variables, truePositives, trueNegatives + * , falsePositives and falseNegatives that are used to compute the + * AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of + * recall and precision values. The area under the ROC-curve is therefore computed using the height + * of the recall values by the false positive rate, while the area under the PR-curve is the + * computed using the height of the precision values by the recall. * - *

This value is ultimately returned as auc, an idempotent operation that computes the area - * under a discretized curve of precision versus recall values (computed using the aforementioned - * variables). The numThresholds variable controls the degree of discretization with larger - * numbers of thresholds more closely approximating the true AUC. The quality of the approximation - * may vary dramatically depending on numThresholds`. The thresholds parameter can be used to - * manually specify thresholds which split the predictions more evenly. + *

This value is ultimately returned as auc, an idempotent operation that computes + * the area under a discretized curve of precision versus recall values (computed using the + * aforementioned variables). The numThresholds variable controls the degree of + * discretization with larger numbers of thresholds more closely approximating the true AUC. The + * quality of the approximation may vary dramatically depending on numThresholds. The + * thresholds parameter can be used to manually specify thresholds which split the + * predictions more evenly. + * + *

For best results, predictions should be distributed approximately uniformly in + * the range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor + * if this is not the case. Setting summationMethod to minoring or + * majoring can help quantify the error in the approximation by providing lower or upper + * bound estimate of the AUC. + * + *

Usage:
* - *

For best results, predictions should be distributed approximately uniformly in the range [0, - * 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor if this is not - * the case. Setting summationMethod to minoring or majoring can help quantify the error in - * the approximation by providing lower or upper bound estimate of the AUC. - *

- *

- * Usage:
*

  * AUC m = new  getTF().keras.metrics.AUC( getTF(), 3);
  * m.updateState( getTF().constant(new float[] {0, 0, 1,1}),
@@ -64,10 +66,11 @@
  * // tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
  * // recall = [1, 0.5, 0], fpRate = [1, 0, 0]
  * // auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75
- * Operand<TFloat32> result = m.result();
+ * Operand<TFloat32> result = m.result();
  * System.out.println(result.data().getFloat());
  * 0.75
  * 
+ * *
  * m.resetStates()
  * m.updateState( getTF().constant(new float[] {0, 0, 1, 1}),
@@ -170,7 +173,7 @@ public AUC(Ops tf, String name, long seed, Class type) {
    *
    * @param tf The TensorFlow Ops
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param seed the seed for random number generation. An initializer created with a given seed
    *     will always produce the same random tensor for a given shape and data type.
    * @param type the data type for the confusion matrix variables.
@@ -224,7 +227,7 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) {
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME}
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param seed the seed for random number generation. An initializer created with a given seed
    *     will always produce the same random tensor for a given shape and data type.
    * @param type the data type for the confusion matrix variables.
@@ -279,7 +282,7 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) {
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME}
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param seed the seed for random number generation. An initializer created with a given seed
@@ -336,7 +339,7 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C
    *
    * @param tf The TensorFlow Ops
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param seed the seed for random number generation. An initializer created with a given seed
@@ -392,7 +395,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type)
    *
    * @param tf The TensorFlow Ops
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param summationMethod Specifies the Riemann summation method used
@@ -442,7 +445,7 @@ public AUC(
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME}
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1.
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param summationMethod Specifies the Riemann summation method used,
@@ -462,7 +465,7 @@ public AUC(
   }
 
   /**
-   * Creates an AUC (Area under the curve) metric. using null> for the numThresholds,
+   * Creates an AUC (Area under the curve) metric. using null for the numThresholds,
    * false for multiLabel, and null for labelWeights.
    *
    * @param tf The TensorFlow Ops
@@ -493,7 +496,7 @@ public AUC(
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}.
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS}
+   *     must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS}
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param summationMethod Specifies the Riemann summation method used
@@ -577,7 +580,7 @@ public AUC(
                       .greaterEqual(
                           labelWeights, cast(getTF(), getTF().constant(0), labelWeights.type())),
                   Collections.singletonList(
-                      getTF().constant("All values of `labelWeights` must be non-negative.")));
+                      getTF().constant("All values of labelWeights must be non-negative.")));
 
       Ops ltf =
           getTF()
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java
index 09581c726d3..60687dd9005 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java
@@ -17,11 +17,11 @@
 /**
  * Specifies the Riemann summation method used. {@link #INTERPOLATION} (default) applies mid-point
  * summation scheme for ROC. For PR-AUC, interpolates (true/false) positives but not the ratio that
- * is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left summation
- * for increasing intervals and right summation for decreasing intervals; {@link #MAJORING} does the
- * opposite.
+ * is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left
+ * summation for increasing intervals and right summation for decreasing intervals; {@link
+ * #MAJORING} does the opposite.
  *
- * @see Davis & Goadrich. 2006
+ * @see Davis & Goadrich. 2006
  * @see Riemann summation method
  */
 public enum AUCSummationMethod {
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java
index f69170e57b9..9548fb42c65 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java
@@ -32,7 +32,7 @@
  * ultimately returned as binary accuracy: an idempotent operation that simply divides total by
  * count.
  *
- * 

If sampleWeights is null>, weights default to 1. Use sampleWeights of 0 to mask + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask * values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index 9e7f0f874cc..d2a414fdeb7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -30,7 +30,7 @@ * ultimately returned as binary accuracy: an idempotent operation that simply divides total by * count. * - *

If sampleWeights is null>, weights default to 1. Use sampleWeights of 0 to mask + *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask * values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java index cf6f84af512..39d33dda665 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -26,13 +26,12 @@ * This metric creates one local variable, accumulator that is used to keep track of * the number of false negatives. * - *

If sampleWeightsnull - * sampleWeightsIf sampleWeights is null, weights default to 1. Use + * sampleWeights of 0 to mask values. + * * @param The data type for the metric result */ -public class FalseNegatives - extends ConfusionMatrixConditionCount { +public class FalseNegatives extends ConfusionMatrixConditionCount { /** * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name and a diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java index 629caaafb52..3cf9fc0a5e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -26,14 +26,12 @@ * This metric creates one local variable, accumulator that is used to keep track of * the number of false positives. * - *

If sampleWeightsnull - * sampleWeightsIf sampleWeights is null, weights default to 1. Use + * sampleWeights of 0 to mask values. * * @param The data type for the metric result */ -public class FalsePositives< T extends TNumber> - extends ConfusionMatrixConditionCount { +public class FalsePositives extends ConfusionMatrixConditionCount { /** * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name and a diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index c8205565802..19b13ed391c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -39,8 +39,8 @@ * / (true_positive + false_positive + false_negative). The predictions are accumulated in a * confusion matrix, weighted by sample_weight and the metric is then calculated from it. * - *

If sampleWeight is null, weights default to 1. Use sample_weight of 0 to mask - * values. + *

If sampleWeight is null, weights default to 1. Use sample_weight of + * 0 to mask values. * * @param The data type for the metric result */ @@ -72,7 +72,7 @@ protected MeanIoU(Ops tf, long numClasses, long seed, Class type) { } /** - * create a metric with reduction = AUTO + * Creates a MeanIoU metric * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index eb8ccaf76d2..4c48c0f88a7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -42,17 +42,20 @@ public class MeanRelativeError extends Mean { private Operand normalizer; /** - * create a metric with name = class name and reduction = AUTO + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result */ protected MeanRelativeError(Ops tf, float[] normalizer, long seed, Class type) { this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); } /** - * create a metric with reduction = AUTO + * Creates a MeanRelativeError metric * * @param tf the TensorFlow Ops * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. 
@@ -66,7 +69,7 @@ protected MeanRelativeError(Ops tf, String name, float[] normalizer, long seed, } /** - * Creates a MeanRelativeError metric with name = class name and reduction = AUTO + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. @@ -79,7 +82,7 @@ protected MeanRelativeError(Ops tf, double[] normalizer, long seed, Class typ } /** - * create a metric with reduction = AUTO + * Creates a MeanRelativeError metric * * @param tf the TensorFlow Ops * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. @@ -93,7 +96,7 @@ protected MeanRelativeError(Ops tf, String name, double[] normalizer, long seed, } /** - * create a metric with name = class name and reduction = AUTO + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name * * @param tf the TensorFlow Ops * @param normalizer The normalizer values with same shape as predictions. @@ -106,7 +109,7 @@ protected MeanRelativeError(Ops tf, Operand normalizer, long seed, Class t } /** - * create a metric + * Creates a MeanRelativeError metric * * @param tf the TensorFlow ops * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index d9c767965a6..3d6d8194aac 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -103,7 +103,7 @@ private boolean init(Shape shape) { } } - /** {@inheritDoc */ + /** {@inheritDoc} */ @Override public List updateStateList( Operand values, Operand sampleWeights) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index 6b70c6680cb..ee87cebfa48 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -31,19 +31,22 @@ /** * Computes the precision of the predictions with respect to the labels. * - *

The metric creates two local variables, truePositives and falsePositives that are used to - * compute the precision. This value is ultimately returned as precision, an idempotent operation - * that simply divides truePositives by the sum of truePositives and falsePositives. + *

The metric creates two local variables, truePositives and falsePositives + * that are used to compute the precision. This value is ultimately returned as precision, + * an idempotent operation that simply divides truePositives by the sum of + * truePositives and falsePositives. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values. + *

If sampleWeights is null, weights default to 1. Use sampleWeights of + * 0 to mask values. * - *

If is set, the metric calculates precision as how often on average a class among the top-k - * classes with the highest predicted values of a batch entry is correct and can be found in the - * label for that entry. + *

If topK is set, the metric calculates precision as how often on average a class + * among the top-k classes with the highest predicted values of a batch entry is correct and can be + * found in the label for that entry. * - *

If classId is specified, the metric calculates precision by considering only the entries in the batch - * for which classId is above the thresholds and/or in the top-k highest predictions, and computing - * the fraction of them for which classId is indeed a correct label. + *

If classId is specified, the metric calculates precision by considering only the + * entries in the batch for which classId is above the thresholds and/or + * in the top-k highest predictions, and computing the fraction of them for which classId + * is indeed a correct label. * * @param The data type for the metric result */ @@ -58,13 +61,13 @@ public class Precision extends Metric { private final String truePositivesName; private final String falsePositivesName; private final Class type; + private final List initializers = new ArrayList<>(); private Variable truePositives; private Variable falsePositives; - private final List initializers = new ArrayList<>(); /** * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId - * values and with a threshold of {@link #DEFAULT_THRESHOLD).} + * values and with a threshold of {@link #DEFAULT_THRESHOLD}. * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed @@ -77,7 +80,7 @@ public Precision(Ops tf, long seed, Class type) { /** * Creates a Precision Metric with no topK or classId values with a threshold of {@link - * #DEFAULT_THRESHOLD).} + * #DEFAULT_THRESHOLD}. * * @param tf the TensorFlow Ops * @param name name of the metric instance. 
If null, name defaults to {@link @@ -276,11 +279,8 @@ private void init() { Operand zero = zeros.call(tf.constant(Shape.of(thresholds.length)), type); if (this.truePositives == null) { - this.truePositives = - tf.withName(truePositivesName) - .variable(zero); + this.truePositives = tf.withName(truePositivesName).variable(zero); initializers.add(tf.assign(truePositives, zero)); - } if (this.falsePositives == null) { this.falsePositives = @@ -293,8 +293,10 @@ private void init() { /** {@inheritDoc} */ @Override @SuppressWarnings("unchecked") - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { Map> confusionMatrix = new HashMap<>(); confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives); @@ -314,7 +316,7 @@ public List updateStateList( thresholds, topK, classId, - tSampleWeights, + tSampleWeights, false, null)); } @@ -323,8 +325,7 @@ public List updateStateList( @Override public Operand result() { Ops tf = getTF(); - Operand result = - tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + Operand result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); return thresholds.length == 1 ? 
tf.slice( result, @@ -351,7 +352,7 @@ public float[] getThresholds() { /** * Gets the topK value, may be null * - * @return the topK + * @return the topK value or null */ public Integer getTopK() { return topK; @@ -360,7 +361,7 @@ public Integer getTopK() { /** * Gets the classId, may be null * - * @return the classId + * @return the classId or null */ public Integer getClassId() { return classId; @@ -375,7 +376,11 @@ public Variable getTruePositives() { return truePositives; } - /** Gets the falsePositives variable return the falsePositives */ + /** + * Gets the falsePositives variable + * + * @return the falsePositives + */ public Variable getFalsePositives() { return falsePositives; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 2ec66df0ca9..299c649279f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -24,10 +24,17 @@ /** * Computes best precision where recall is >= specified value. + * + *

This metric creates four local variables, truePositives, trueNegatives, falsePositives and + * falseNegatives that are used to compute the precision at the given recall. The threshold for the + * given recall value is computed and used to evaluate the corresponding precision. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of + * 0 to mask values. + * * @param The data type for the metric result */ -public class PrecisionAtRecall - extends SensitivitySpecificityBase { +public class PrecisionAtRecall extends SensitivitySpecificityBase { private final float recall; @@ -40,7 +47,8 @@ public class PrecisionAtRecall * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { this(tf, null, recall, DEFAULT_NUM_THRESHOLDS, seed, type); @@ -56,7 +64,8 @@ public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class type) { this(tf, name, recall, DEFAULT_NUM_THRESHOLDS, seed, type); @@ -72,7 +81,8 @@ public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. 
+ * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Class type) { this(tf, null, recall, numThresholds, seed, type); @@ -89,7 +99,8 @@ public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Cla * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public PrecisionAtRecall( Ops tf, String name, float recall, int numThresholds, long seed, Class type) { @@ -104,8 +115,7 @@ public Operand result() { Ops tf = getTF(); Operand recall = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); Operand sub = tf.math.sub(recall, cast(tf, tf.constant(value), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); @@ -115,7 +125,11 @@ public Operand result() { return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); } - /** @return the recall */ + /** + * Gets the recall value + * + * @return the recall value + */ public float getRecall() { return recall; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 0672b78f229..e1eebb98f77 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -30,22 +30,25 @@ /** * Computes the recall of the predictions with respect to the labels. - *

This metric creates two local - * variables, truePositives and falseNegatives, that are used to compute the recall. This value is - * ultimately returned as recall, an idempotent operation that simply divides truePositives by the sum of truePositives and falseNegatives. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask values. + *

This metric creates two local variables, truePositives and falseNegatives + * , that are used to compute the recall. This value is ultimately returned as recall, an + * idempotent operation that simply divides truePositives by the sum of + * truePositives and falseNegatives. * - *

If is set, the metric calculates recall as how often on average a class among the labels of a - * batch entry is in the top-k predictions. + *

If sampleWeights is null, weights default to 1. Use sampleWeights of + * 0 to mask values. * - *

If classId is specified, the metric calculates recall by considering only the entries in the batch - * for which classId is in the label, and computing the fraction of them for which classId is above - * the threshold and/or in the top-k predictions. + *

If topK is set, the metric calculates recall as how often on average a class + * among the labels of a batch entry is in the top-k predictions. + * + *

If classId is specified, the metric calculates recall by considering only the + * entries in the batch for which classId is in the label, and computing the fraction + * of them for which classId is above the threshold and/or in the top-k predictions. * * @param The data type for the metric result */ -public class Recall< T extends TNumber> extends Metric< T> { +public class Recall extends Metric { public static final float DEFAULT_THRESHOLD = 0.5f; public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; @@ -56,9 +59,9 @@ public class Recall< T extends TNumber> extends Metric< T> { private final String truePositivesName; private final String falseNegativesName; private final Class type; + private final List initializers = new ArrayList<>(); private Variable truePositives; private Variable falseNegatives; - private final List initializers = new ArrayList<>(); /** * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set @@ -301,17 +304,13 @@ private void init() { Operand zero = zeros.call(tf.constant(Shape.of(this.thresholds.length)), type); if (truePositives == null) { - truePositives = - tf.withName(truePositivesName) - .variable(zero); + truePositives = tf.withName(truePositivesName).variable(zero); initializers.add(tf.assign(truePositives, zero)); } - + if (this.falseNegatives == null) { - falseNegatives = - tf.withName(falseNegativesName) - .variable(zero); + falseNegatives = tf.withName(falseNegativesName).variable(zero); initializers.add(tf.assign(falseNegatives, zero)); } } @@ -326,7 +325,9 @@ public Op resetStates() { @Override @SuppressWarnings("unchecked") public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + Operand labels, + Operand predictions, + Operand sampleWeights) { Map> confusionMatrix = new HashMap<>(); confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives); @@ -345,17 +346,16 
@@ public List updateStateList( this.thresholds, this.topK, this.classId, - tSampleWeights, + tSampleWeights, false, null); } @Override public Operand result() { - Ops tf = getTF(); + Ops tf = getTF(); Operand result = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); return this.thresholds.length == 1 ? tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1])) : result; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index 6c774f0c765..fb6890d1e01 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -24,8 +24,22 @@ import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; import static org.tensorflow.framework.utils.CastHelper.cast; -public class RecallAtPrecision - extends SensitivitySpecificityBase { +/** + * Computes best recall where precision is >= specified value. + * + *

For a given score-label-distribution the required precision might not be achievable, in this + * case 0.0 is returned as recall. + * + *

This metric creates four local variables, truePositives, trueNegatives, falsePositives and + * falseNegatives that are used to compute the recall at the given precision. The threshold for the + * given precision value is computed and used to evaluate the corresponding recall. + * + *

If sampleWeights is null, weights default to 1. Use sampleWeights of + * 0 to mask values. + * + * @param The data type for the metric result + */ +public class RecallAtPrecision extends SensitivitySpecificityBase { private final float precision; @@ -38,7 +52,8 @@ public class RecallAtPrecision * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { this(tf, null, precision, DEFAULT_NUM_THRESHOLDS, seed, type); @@ -54,7 +69,8 @@ public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class type) { this(tf, name, precision, DEFAULT_NUM_THRESHOLDS, seed, type); @@ -70,7 +86,8 @@ public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class< * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. 
+ * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, Class type) { this(tf, null, precision, numThresholds, seed, type); @@ -87,7 +104,8 @@ public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range [0-1]. + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. */ public RecallAtPrecision( Ops tf, String name, float precision, int numThresholds, long seed, Class type) { @@ -103,18 +121,15 @@ public Operand result() { Ops tf = getTF(); Operand precisions = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falsePositives)); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falsePositives)); Operand recalls = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); Operand isFeasible = tf.math.greaterEqual(precisions, cast(tf, tf.constant(this.value), getType())); Where feasible = tf.where(isFeasible); Operand feasibleExists = tf.math.greater(tf.size(feasible), tf.constant(0)); - Operand gather = - tf.expandDims(tf.gather(recalls, feasible, tf.constant(0)), tf.constant(0)); + Operand gather = tf.expandDims(tf.gather(recalls, feasible, tf.constant(0)), tf.constant(0)); return tf.select( feasibleExists, tf.reduceMax(gather, allAxes(tf, gather)), diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 2133642564b..9b4401964d7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -27,12 +27,12 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes root mean squared error metric between labels> and predictions + * Computes root mean squared error metric between labels and predictions * . * * @param The data type for the metric result */ -public class RootMeanSquaredError< T extends TNumber> extends Mean< T> { +public class RootMeanSquaredError extends Mean { /** * Creates a RootMeanSquaredError metric with a name of {@link Class#getSimpleName()} @@ -62,12 +62,15 @@ public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { /** {@inheritDoc} */ @Override - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); - Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + Operand tSampleWeights = + sampleWeights != null ? 
cast(getTF(), sampleWeights, getResultType()) : null; LossTuple ops = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); tPredictions = ops.getTarget(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 7cf694868e6..2c7420a5518 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -23,7 +23,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes best sensitivity where sensitivity is >= specified value. + * Computes best sensitivity where sensitivity is >= specified value. * *

Sensitivity measures the proportion of actual positives that are correctly * identified as such (tp / (tp + fn)). @@ -36,15 +36,14 @@ * sensitivity at the given specificity. The threshold for the given specificity value is computed * and used to evaluate the corresponding sensitivity. * - *

If sampleWeights is null>, weights default to 1. Use sample_weight - * of 0 to mask values. + *

If sampleWeights is null, weights default to 1. Use sample_weight of + * 0 to mask values. * * @see Additional information * about specificity and sensitivity * @param The data type for the metric result */ -public class SensitivityAtSpecificity - extends SensitivitySpecificityBase { +public class SensitivityAtSpecificity extends SensitivitySpecificityBase { private final float specificity; @@ -57,7 +56,7 @@ public class SensitivityAtSpecificity * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class type) { @@ -74,7 +73,7 @@ public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class t * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ public SensitivityAtSpecificity( @@ -92,7 +91,7 @@ public SensitivityAtSpecificity( * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. 
*/ public SensitivityAtSpecificity( @@ -111,7 +110,7 @@ public SensitivityAtSpecificity( * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range * [0-1]. */ public SensitivityAtSpecificity( @@ -127,10 +126,8 @@ public SensitivityAtSpecificity( public Operand result() { Ops tf = getTF(); Operand specificities = - tf.math.divNoNan( - this.trueNegatives, tf.math.add(this.trueNegatives, this.falsePositives)); - Operand sub = - tf.math.sub(specificities, cast(tf, tf.constant(this.getValue()), getType())); + tf.math.divNoNan(this.trueNegatives, tf.math.add(this.trueNegatives, this.falsePositives)); + Operand sub = tf.math.sub(specificities, cast(tf, tf.constant(this.getValue()), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 156a4995b02..7034861d8d2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -31,24 +31,23 @@ /** * Calculates how often predictions matches integer labels. * - *

You can provide logits of classes as predictions, since argmax of logits and probabilities are - * same. + *

You can provide logits of classes as predictions, since argmax of logits and + * probabilities are same. * *

This metric creates two local variables, `total` and `count` that are used to compute the - * frequency with which predictions matches labels. This frequency is ultimately returned as `sparse - * categorical accuracy`: an idempotent operation that simply divides `total` by `count`. + * frequency with which predictions matches labels. This frequency is + * ultimately returned as `sparse categorical accuracy`: an idempotent operation that simply divides + * `total` by `count`. * *

If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.' * *

Usage: * - *

- * *

  * SparseCategoricalAccuracy m = new tf.keras.metrics.SparseCategoricalAccuracy();
  * m.update_state(tf.constant(new float[][] {{2}, {1}},
  *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, [{0.05f, 0.95f, 0f}});
- * Operand<TFloat32>> result = m.result();
+ * Operand<TFloat32> result = m.result();
  * System.out.println(result.data().getFloat());
  * 0.5
  * 
@@ -87,7 +86,7 @@ public class SparseCategoricalAccuracy extends MeanMetricWrap * will always produce the same random tensor for a given shape and data type. * @param type The data type for the metric result */ - public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { + public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { this(tf, null, seed, type); } @@ -100,7 +99,7 @@ public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { * will always produce the same random tensor for a given shape and data type. * @param type the type of the metric result. */ - public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); super.setLoss(this); } @@ -108,8 +107,7 @@ public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) /** {@inheritDoc} */ @Override public Operand call( - Operand labels, - Operand predictions) { + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index 59f6f44c1f2..d0b797690bd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -23,7 +23,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes best specificity where sensitivity is >= specified value. Sensitivity + * Computes best specificity where sensitivity is >= specified value. Sensitivity * measures the proportion of actual positives that are correctly identified as such * (tp / (tp + fn)). 
* @@ -35,15 +35,14 @@ * specificity at the given sensitivity. The threshold for the given sensitivity value is computed * and used to evaluate the corresponding specificity. * - *

If sampleWeights is null>, weights default to 1. Use sample_weight - * of 0 to mask values. + *

If sampleWeights is null, weights default to 1. Use sample_weight of + * 0 to mask values. * * @see Additional information * about specificity and sensitivity * @param The data type for the metric result */ -public class SpecificityAtSensitivity - extends SensitivitySpecificityBase { +public class SpecificityAtSensitivity extends SensitivitySpecificityBase { private final float sensitivity; @@ -56,7 +55,7 @@ public class SpecificityAtSensitivity * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. */ public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class type) { @@ -73,7 +72,7 @@ public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class t * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. */ public SpecificityAtSensitivity( @@ -91,7 +90,7 @@ public SpecificityAtSensitivity( * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. 
*/ public SpecificityAtSensitivity( @@ -110,7 +109,7 @@ public SpecificityAtSensitivity( * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range * [0-1]. */ public SpecificityAtSensitivity( @@ -124,14 +123,12 @@ public SpecificityAtSensitivity( /** {@inheritDoc} */ @Override public Operand result() { - + Ops tf = getTF(); Operand sensitivities = - tf.math.divNoNan( - this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); - Operand sub = - tf.math.sub(sensitivities, cast(tf, tf.constant(this.getValue()), getType())); + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + Operand sub = tf.math.sub(sensitivities, cast(tf, tf.constant(this.getValue()), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java index 4312d7a97f0..a3241221b66 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -28,8 +28,6 @@ * values. This is ultimately returned as sum. * *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. - * - */ public class Sum extends Reduce { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index d2db4f368ac..ad78e48bc34 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -22,12 +22,13 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -/** Computes the poisson loss metric between labels and predictions. +/** + * Computes the poisson loss metric between labels and predictions. * * @param The data type for the metric result */ -public class TopKCategoricalAccuracy - extends MeanMetricWrapper implements LossMetric { +public class TopKCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { public static final int DEFAULT_K = 5; /** Number of top elements to look at for computing accuracy. */ private final int k; @@ -40,6 +41,7 @@ public class TopKCategoricalAccuracy * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. + * @param type The data type for the metric result */ public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { this(tf, name, DEFAULT_K, seed, type); @@ -53,6 +55,7 @@ public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { * @param k Number of top elements to look at for computing accuracy. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
+ * @param type The data type for the metric result */ public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { super(tf, name, seed, type); @@ -62,7 +65,8 @@ public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class t /** {@inheritDoc} */ @Override - public Operand call(Operand labels, Operand predictions) { + public Operand call( + Operand labels, Operand predictions) { Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); return Metrics.topKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java index de6428fed88..91b6751588a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -26,14 +26,12 @@ * This metric creates one local variable, accumulator that is used to keep track of * the number of true negatives. * - *

If sampleWeightsnull, weights - * default to 1. Use + *

If sampleWeights is null, weights default to 1. Use * sampleWeights of 0 to mask values. * * @param The data type for the metric result */ -public class TrueNegatives - extends ConfusionMatrixConditionCount { +public class TrueNegatives extends ConfusionMatrixConditionCount { /** * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name and a diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java index c573b6b5719..b67d381a62d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -26,13 +26,12 @@ * This metric creates one local variable, accumulator that is used to keep track of * the number of true positives. * - *

If sampleWeightsnull, weights - * default to 1. Use + *

If sampleWeights is null, weights default to 1. Use * sampleWeights of 0 to mask values. + * * @param The data type for the metric result */ -public class TruePositives - extends ConfusionMatrixConditionCount< T> { +public class TruePositives extends ConfusionMatrixConditionCount { /** * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name and a diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java index b76356661a9..281aa2072d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -27,7 +27,12 @@ public enum ConfusionMatrixEnum { private final String abbrev; - /** Creates a ConfusionMatrixEnum */ + /** + * Creates a ConfusionMatrixEnum + * + * @param abbrev the abbreviation for the confusion condition as required by the underlying + * TensorFlow api. + */ ConfusionMatrixEnum(String abbrev) { this.abbrev = abbrev; } @@ -50,7 +55,11 @@ public static ConfusionMatrixEnum get(String item) { return null; } - /** Gets the abbreviation for this enum value */ + /** + * Gets the abbreviation for this enum value + * + * @return the abbreviation for this enum value as required by the underlying TensorFlow api. 
+ */ public String getAbbreviation() { return abbrev; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index cbb24933967..0be0a7a572a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -283,10 +283,10 @@ public static List assertShapes( *

For every pair of values in labels and predictions: * *

-   * TRUE_POSITIVES:  labels == true and predictions > thresholds
-   * FALSE_POSITIVES: labels == true and predictions <= thresholds
-   * TRUE_NEGATIVES:  labels == false and predictions <= thresholds
-   * FALSE_NEGATIVE:  labels == false and predictions > thresholds
+   * TRUE_POSITIVES:  labels == true and predictions > thresholds
+   * FALSE_POSITIVES: labels == true and predictions <= thresholds
+   * TRUE_NEGATIVES:  labels == false and predictions <= thresholds
+   * FALSE_NEGATIVE:  labels == false and predictions > thresholds
    * 
* *

The results will be weighted and added together. When multiple thresholds are provided, we @@ -324,7 +324,7 @@ public static List assertShapes( * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. * @param the data type for the variables * @throws IllegalArgumentException If predictions and labels have - * mismatched shapes, or if sampleWeight is not null>and its shape + * mismatched shapes, or if sampleWeight is not nulland its shape * doesn't match predictions * @return an op to update the given confusion matrix variables. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 3949ede822a..377124333bd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -55,7 +55,7 @@ public abstract class SensitivitySpecificityBase extends Metr * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the variables - * @throws IllegalArgumentException if numThresholds <= 0. + * @throws IllegalArgumentException if numThresholds <= 0. 
*/ protected SensitivitySpecificityBase( Ops tf, String name, float value, int numThresholds, long seed, Class type) { @@ -114,28 +114,29 @@ private void init() { public Op initializeVariables() { List varInitializers = new ArrayList<>(); - if(truePositivesInitializer != null ) { + if (truePositivesInitializer != null) { varInitializers.add(truePositivesInitializer); } - if(falsePositivesInitializer != null ) { + if (falsePositivesInitializer != null) { varInitializers.add(falsePositivesInitializer); } - if(trueNegativesInitializer != null ) { + if (trueNegativesInitializer != null) { varInitializers.add(trueNegativesInitializer); } - if(falseNegativesInitializer != null ) { + if (falseNegativesInitializer != null) { varInitializers.add(falseNegativesInitializer); } return getTF().withControlDependencies(varInitializers).noOp(); - } /** {@inheritDoc} */ @Override @SuppressWarnings("unchecked") - public List updateStateList( - Operand labels, Operand predictions, Operand sampleWeights) { + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { Ops tf = getTF(); Operand tLabels = cast(tf, labels, type); Operand tPredictions = cast(tf, predictions, type); @@ -156,7 +157,7 @@ public List updateStateList( this.getThresholds(), null, null, - tSampleWeights, + tSampleWeights, false, null); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 09752798ad5..36792b8ea7a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -153,7 +153,7 @@ private static Operand hasValidDims( * *

This returns a version of weights following the same broadcast rules as * mul(weights, - * values), but limited to the weights shapes allowed by assertBroadcastable + * values), but limited to the weights shapes allowed by assertBroadcastable * When computing a weighted average, use this function to broadcast weights before * summing them; e.g., reduceSum(w * v) / reduceSum(_broadcast_weights(w, v)). * From 36e093498f40bcc1559b07dbec306f0f973a737f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 2 Mar 2021 10:45:02 -0500 Subject: [PATCH 51/97] Change thresholds to Operand --- .../org/tensorflow/framework/metrics/AUC.java | 309 ++++++++---------- .../framework/metrics/Precision.java | 16 +- .../tensorflow/framework/metrics/Recall.java | 16 +- .../impl/ConfusionMatrixConditionCount.java | 9 +- .../framework/metrics/impl/MetricsHelper.java | 113 ++++--- .../impl/SensitivitySpecificityBase.java | 10 +- .../framework/metrics/PrecisionTest.java | 2 +- 7 files changed, 213 insertions(+), 262 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index da89167e1f3..8a31dfd3fce 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -58,9 +58,9 @@ *

Usage:
* *

- * AUC m = new  getTF().keras.metrics.AUC( getTF(), 3);
- * m.updateState( getTF().constant(new float[] {0, 0, 1,1}),
- *          getTF().constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
+ * AUC m = new  tf.keras.metrics.AUC( tf, 3);
+ * m.updateState( tf.constant(new float[] {0, 0, 1,1}),
+ *          tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
  *
  * // threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
  * // tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
@@ -73,9 +73,9 @@
  *
  * 
  * m.resetStates()
- * m.updateState( getTF().constant(new float[] {0, 0, 1, 1}),
- *                 getTF().constant(new float[] {0f, 0.5f, 0.3f, 0.9f}, ),
- *                 getTF().constant(new float[] {1, 0, 0, 1}));
+ * m.updateState( tf.constant(new float[] {0, 0, 1, 1}),
+ *                 tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}, ),
+ *                 tf.constant(new float[] {1, 0, 0, 1}));
  * result = m.result();
  * System.out.println(result.data().getFloat());
  * 1.0
@@ -209,7 +209,7 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) {
     this(
         tf,
         null,
-        null,
+        DEFAULT_NUM_THRESHOLDS,
         AUCCurve.ROC,
         AUCSummationMethod.INTERPOLATION,
         thresholds,
@@ -264,7 +264,7 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) {
     this(
         tf,
         name,
-        null,
+            DEFAULT_NUM_THRESHOLDS,
         AUCCurve.ROC,
         AUCSummationMethod.INTERPOLATION,
         thresholds,
@@ -322,7 +322,7 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C
     this(
         tf,
         name,
-        null,
+            DEFAULT_NUM_THRESHOLDS,
         curve,
         AUCSummationMethod.INTERPOLATION,
         thresholds,
@@ -378,7 +378,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type)
     this(
         tf,
         null,
-        null,
+            DEFAULT_NUM_THRESHOLDS,
         curve,
         AUCSummationMethod.INTERPOLATION,
         thresholds,
@@ -435,7 +435,7 @@ public AUC(
       AUCSummationMethod summationMethod,
       long seed,
       Class type) {
-    this(tf, null, null, curve, summationMethod, thresholds, false, null, seed, type);
+    this(tf, null, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type);
   }
 
   /**
@@ -487,7 +487,7 @@ public AUC(
       AUCSummationMethod summationMethod,
       long seed,
       Class type) {
-    this(tf, name, null, curve, summationMethod, thresholds, false, null, seed, type);
+    this(tf, name, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type);
   }
 
   /**
@@ -496,7 +496,7 @@ public AUC(
    * @param tf The TensorFlow Ops
    * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}.
    * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
-   *     must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS}
+   *     must be > 1.
    * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
    *     AUCCurve#PR} for the Precision-Recall-curve.
    * @param summationMethod Specifies the Riemann summation method used
@@ -520,7 +520,7 @@ public AUC(
   public AUC(
       Ops tf,
       String name,
-      Integer numThresholds,
+      int numThresholds,
       AUCCurve curve,
       AUCSummationMethod summationMethod,
       float[] thresholds,
@@ -529,10 +529,10 @@ public AUC(
       long seed,
       Class type) {
     super(tf, name == null ? DEFAULT_NAME : name, seed);
-    this.truePositivesName = this.getVariableName(TRUE_POSITIVES);
-    this.falsePositivesName = this.getVariableName(FALSE_POSITIVES);
-    this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES);
-    this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES);
+    truePositivesName = getVariableName(TRUE_POSITIVES);
+    falsePositivesName = getVariableName(FALSE_POSITIVES);
+    trueNegativesName = getVariableName(TRUE_NEGATIVES);
+    falseNegativesName = getVariableName(FALSE_NEGATIVES);
     this.curve = curve;
     this.summationMethod = summationMethod;
     this.type = type;
@@ -540,18 +540,23 @@ public AUC(
     this.multiLabel = multiLabel;
 
     if (thresholds != null) { // ignore numThresholds
-      for (float t : thresholds)
-        if (t < 0.0f || t > 1.0f)
+      for (float t : thresholds) {
+        if (t < 0.0f || t > 1.0f) {
           throw new IllegalArgumentException(
               String.format(
                   "Threshold values must be in [0, 1]. Invalid values: %s",
                   Arrays.toString(thresholds)));
+        }
+      }
       this.numThresholds = thresholds.length + 2;
       Arrays.sort(thresholds);
     } else {
-      if (numThresholds <= 1) throw new IllegalArgumentException("numThresholds must be > 1.");
+
+      if (numThresholds <= 1) {
+        throw new IllegalArgumentException("numThresholds must be > 1.");
+      }
       this.numThresholds = numThresholds;
-      thresholds = new float[numThresholds - 2];
+      thresholds = new float[this.numThresholds - 2];
       // linearly interpolate (numThresholds - 2) thresholds between endpoints
       for (int i = 0; i < thresholds.length; i++) {
         thresholds[i] = (i + 1) * 1.0f / (this.numThresholds - 1);
@@ -559,39 +564,38 @@ public AUC(
     }
     // Add an endpoint "threshold" below zero and above one for either
     // threshold method to account for floating point imprecision.
-    if (thresholds.length != this.numThresholds - 2)
+    if (thresholds.length != this.numThresholds - 2) {
       throw new IllegalArgumentException(
           "Thresholds length must contain numThresholds - 2 entries");
+    }
+    // Add an endpoint "threshold" below zero and above one for either
+    // threshold method to account for floating point imprecisions.
     this.thresholds = new float[this.numThresholds];
     this.thresholds[0] = -EPSILON;
     System.arraycopy(thresholds, 0, this.thresholds, 1, thresholds.length);
     this.thresholds[this.numThresholds - 1] = 1 + EPSILON;
 
+    // Handle multilabel arguments.
+
     if (labelWeights != null) {
       // assert that labelWeights are non-negative.
 
       this.labelWeights = labelWeights;
       Op checks =
-          getTF()
-              .withSubScope("AUC")
+          tf.withSubScope("AUC")
               .assertThat(
-                  getTF()
-                      .math
-                      .greaterEqual(
-                          labelWeights, cast(getTF(), getTF().constant(0), labelWeights.type())),
+                  tf.math.greaterEqual(labelWeights, cast(tf, tf.constant(0), labelWeights.type())),
                   Collections.singletonList(
-                      getTF().constant("All values of labelWeights must be non-negative.")));
+                      tf.constant("All values of labelWeights must be non-negative.")));
 
       Ops ltf =
-          getTF()
-              .withSubScope("updateState")
-              .withControlDependencies(Collections.singletonList(checks));
+          tf.withSubScope("updateState").withControlDependencies(Collections.singletonList(checks));
 
       this.labelWeights = ltf.identity(this.labelWeights);
     }
 
-    if (this.multiLabel) {
-      this.numLabels = null;
+    if (multiLabel) {
+      numLabels = null;
     }
   }
 
@@ -607,6 +611,7 @@ private Map> build(Shape shape) {
     if (initialized) {
       return Collections.EMPTY_MAP;
     }
+    Ops tf = getTF();
 
     if (this.isMultiLabel()) {
       if (shape == null) {
@@ -623,26 +628,27 @@ private Map> build(Shape shape) {
       variableShape = Shape.of(this.numThresholds);
     }
 
+    // Create metric variables
     Zeros zeros = new Zeros<>(getTF());
-    Operand zero = zeros.call(getTF().constant(variableShape), type);
+    Operand zero = zeros.call(tf.constant(variableShape), type);
     if (truePositives == null) {
-      truePositives = getTF().withName(getTruePositivesName()).variable(zero);
-      initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTF().assign(truePositives, zero));
+      truePositives = tf.withName(getTruePositivesName()).variable(zero);
+      initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, tf.assign(truePositives, zero));
     }
 
     if (falsePositives == null) {
-      falsePositives = getTF().withName(getFalsePositivesName()).variable(zero);
-      initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, getTF().assign(falsePositives, zero));
+      falsePositives = tf.withName(getFalsePositivesName()).variable(zero);
+      initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, tf.assign(falsePositives, zero));
     }
 
     if (trueNegatives == null) {
-      trueNegatives = getTF().withName(getTrueNegativesName()).variable(zero);
-      initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, getTF().assign(trueNegatives, zero));
+      trueNegatives = tf.withName(getTrueNegativesName()).variable(zero);
+      initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, tf.assign(trueNegatives, zero));
     }
 
     if (falseNegatives == null) {
-      falseNegatives = getTF().withName(getFalseNegativesName()).variable(zero);
-      initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, getTF().assign(falseNegatives, zero));
+      falseNegatives = tf.withName(getFalseNegativesName()).variable(zero);
+      initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, tf.assign(falseNegatives, zero));
     }
 
     this.initialized = true;
@@ -656,19 +662,22 @@ public List updateStateList(
       Operand labels,
       Operand predictions,
       Operand sampleWeights) {
-
-    Operand lLabels = cast(getTF(), labels, type);
-    Operand lPredictions = cast(getTF(), predictions, type);
-    Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null;
+    Ops tf = getTF();
+    Operand tLabels = cast(tf, labels, type);
+    Operand tPredictions = cast(tf, predictions, type);
+    Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
     List updateOperations = new ArrayList<>();
     Map> varInitializers = Collections.EMPTY_MAP;
     if (!this.initialized) {
-      varInitializers = build(lPredictions.shape());
+      varInitializers = build(tPredictions.shape());
     }
     if (this.isMultiLabel() || this.getLabelWeights() != null) {
+      // labels should have shape (number of examples, number of labels).
       List> symbols = new ArrayList<>();
-      symbols.add(new SymbolicShape<>(lLabels, "N", "L"));
+      symbols.add(new SymbolicShape<>(tLabels, "N", "L"));
       if (this.isMultiLabel()) {
+        // TP, TN, FP, and FN should all have shape
+        // (number of thresholds, number of labels).
         symbols.add(new SymbolicShape<>(this.truePositives, "T", "L"));
         symbols.add(new SymbolicShape<>(this.falsePositives, "T", "L"));
         symbols.add(new SymbolicShape<>(this.trueNegatives, "T", "L"));
@@ -678,30 +687,34 @@ public List updateStateList(
         symbols.add(new SymbolicShape<>(this.getLabelWeights(), "L", ""));
       }
       updateOperations.addAll(
-          MetricsHelper.assertShapes(getTF(), symbols, "Number of labels is not consistent."));
-    }
-    if (this.isMultiLabel()) {
-      this.labelWeights = null;
+          MetricsHelper.assertShapes(tf, symbols, "Number of labels is not consistent."));
     }
+
     Map> confusionMatrix = new HashMap<>();
     confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives);
     confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.falsePositives);
     confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.trueNegatives);
     confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives);
 
+    // Only forward labelWeights to update_confusion_matrix_variables when
+    // multiLabel is false. Otherwise the averaging of individual label AUCs is
+    // handled in AUC.result
+    if (this.isMultiLabel()) {
+      this.labelWeights = null;
+    }
     updateOperations.addAll(
         MetricsHelper.updateConfusionMatrixVariables(
-            getTF(),
+            tf,
             confusionMatrix,
             varInitializers,
-            lLabels,
-            lPredictions,
-            this.thresholds,
+            tLabels,
+            tPredictions,
+            tf.constant(thresholds),
             null,
             null,
             tSampleWeights,
-            this.isMultiLabel(),
-            this.getLabelWeights()));
+            isMultiLabel(),
+            getLabelWeights()));
     return updateOperations;
   }
 
@@ -712,147 +725,84 @@ public List updateStateList(
    */
   private Operand interpolatePRAuc() {
     // truePositives[:self.numThresholds - 1]
+    Ops tf = getTF();
     Operand tp0 =
-        getTF()
-            .slice(
-                truePositives,
-                getTF().constant(new int[] {0}),
-                getTF().constant(new int[] {this.getNumThresholds() - 1}));
+        tf.slice(
+            truePositives,
+            tf.constant(new int[] {0}),
+            tf.constant(new int[] {this.getNumThresholds() - 1}));
     // truePositives[1:]
     Operand tp1 =
-        getTF()
-            .slice(
-                truePositives, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1}));
+        tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
 
-    Operand dTP = getTF().math.sub(tp0, tp1);
+    Operand dTP = tf.math.sub(tp0, tp1);
 
-    Operand p = getTF().math.add(truePositives, falsePositives);
+    Operand p = tf.math.add(truePositives, falsePositives);
 
     Operand dP =
-        getTF()
-            .math
-            .sub(
-                getTF()
-                    .slice(
-                        p,
-                        getTF().constant(new int[] {0}),
-                        getTF().constant(new int[] {this.getNumThresholds() - 1})),
-                getTF()
-                    .slice(p, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1})));
+        tf.math.sub(
+            tf.slice(
+                p,
+                tf.constant(new int[] {0}),
+                tf.constant(new int[] {this.getNumThresholds() - 1})),
+            tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})));
 
     Operand precisionSlope =
-        getTF()
-            .math
-            .divNoNan(
-                dTP, getTF().math.maximum(dP, getTF().dtypes.cast(getTF().constant(0), dP.type())));
+        tf.math.divNoNan(dTP, tf.math.maximum(dP, tf.dtypes.cast(tf.constant(0), dP.type())));
 
     Operand intercept =
-        getTF()
-            .math
-            .sub(
-                getTF()
-                    .slice(
-                        truePositives,
-                        getTF().constant(new int[] {1}),
-                        getTF().constant(new int[] {-1})),
-                getTF()
-                    .math
-                    .mul(
-                        precisionSlope,
-                        getTF()
-                            .slice(
-                                p,
-                                getTF().constant(new int[] {1}),
-                                getTF().constant(new int[] {-1}))));
+        tf.math.sub(
+            tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
+            tf.math.mul(
+                precisionSlope,
+                tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))));
 
     Operand safePRatio =
-        getTF()
-            .select(
-                getTF()
-                    .math
-                    .logicalAnd(
-                        getTF()
-                            .math
-                            .greater(
-                                getTF()
-                                    .slice(
-                                        p,
-                                        getTF().constant(new int[] {0}),
-                                        getTF().constant(new int[] {this.getNumThresholds() - 1})),
-                                getTF().dtypes.cast(getTF().constant(0), p.type())),
-                        getTF()
-                            .math
-                            .greater(
-                                getTF()
-                                    .slice(
-                                        p,
-                                        getTF().constant(new int[] {1}),
-                                        getTF().constant(new int[] {-1})),
-                                getTF().dtypes.cast(getTF().constant(0), p.type()))),
-                getTF()
-                    .math
-                    .divNoNan(
-                        getTF()
-                            .slice(
-                                p,
-                                getTF().constant(new int[] {0}),
-                                getTF().constant(new int[] {this.getNumThresholds() - 1})),
-                        getTF()
-                            .math
-                            .maximum(
-                                getTF()
-                                    .slice(
-                                        p,
-                                        getTF().constant(new int[] {1}),
-                                        getTF().constant(new int[] {-1})),
-                                getTF().dtypes.cast(getTF().constant(0), p.type()))),
-                getTF()
-                    .onesLike(
-                        getTF()
-                            .slice(
-                                p,
-                                getTF().constant(new int[] {1}),
-                                getTF().constant(new int[] {-1}))));
+        tf.select(
+            tf.math.logicalAnd(
+                tf.math.greater(
+                    tf.slice(
+                        p,
+                        tf.constant(new int[] {0}),
+                        tf.constant(new int[] {this.getNumThresholds() - 1})),
+                    tf.dtypes.cast(tf.constant(0), p.type())),
+                tf.math.greater(
+                    tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
+                    tf.dtypes.cast(tf.constant(0), p.type()))),
+            tf.math.divNoNan(
+                tf.slice(
+                    p,
+                    tf.constant(new int[] {0}),
+                    tf.constant(new int[] {this.getNumThresholds() - 1})),
+                tf.math.maximum(
+                    tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
+                    tf.dtypes.cast(tf.constant(0), p.type()))),
+            tf.onesLike(tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))));
 
     Operand fn1 =
-        getTF()
-            .slice(
-                falseNegatives, getTF().constant(new int[] {1}), getTF().constant(new int[] {-1}));
+        tf.slice(falseNegatives, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
 
     Operand aucTotalPos =
-        getTF()
-            .math
-            .mul(
-                precisionSlope,
-                getTF().math.add(dTP, getTF().math.mul(intercept, getTF().math.log(safePRatio))));
+        tf.math.mul(
+            precisionSlope, tf.math.add(dTP, tf.math.mul(intercept, tf.math.log(safePRatio))));
 
     Operand prAucIncrement =
-        getTF()
-            .math
-            .divNoNan(
-                aucTotalPos,
-                getTF()
-                    .math
-                    .maximum(
-                        getTF().math.add(tp1, fn1),
-                        getTF().dtypes.cast(getTF().constant(0), this.truePositives.type())));
+        tf.math.divNoNan(
+            aucTotalPos,
+            tf.math.maximum(
+                tf.math.add(tp1, fn1), tf.dtypes.cast(tf.constant(0), this.truePositives.type())));
 
     if (this.isMultiLabel()) {
-      Operand byLabelAuc = getTF().reduceSum(prAucIncrement, getTF().constant(0));
+      Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0));
       if (this.getLabelWeights() == null) {
-        return MetricsHelper.mean(getTF(), byLabelAuc);
+        return MetricsHelper.mean(tf, byLabelAuc);
       } else {
-        return getTF()
-            .math
-            .divNoNan(
-                getTF()
-                    .reduceSum(
-                        getTF().math.mul(byLabelAuc, this.getLabelWeights()),
-                        allAxes(getTF(), byLabelAuc)),
-                getTF().reduceSum(getLabelWeights(), allAxes(getTF(), getLabelWeights())));
+        return tf.math.divNoNan(
+            tf.reduceSum(tf.math.mul(byLabelAuc, this.getLabelWeights()), allAxes(tf, byLabelAuc)),
+            tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights())));
       }
     } else {
-      return getTF().reduceSum(prAucIncrement, allAxes(getTF(), prAucIncrement));
+      return tf.reduceSum(prAucIncrement, allAxes(tf, prAucIncrement));
     }
   }
 
@@ -862,13 +812,13 @@ public Operand result() {
 
     if (this.getCurve() == AUCCurve.PR
         && this.getSummationMethod() == AUCSummationMethod.INTERPOLATION) {
+      // This use case is different and is handled separately.
       return this.interpolatePRAuc();
     }
     Ops tf = getTF();
     Operand x;
     Operand y;
-    Operand recall =
-        getTF().math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives));
+    Operand recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives));
 
     if (this.getCurve() == AUCCurve.ROC) {
       x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives));
@@ -890,7 +840,7 @@ public Operand result() {
     switch (this.getSummationMethod()) {
       case INTERPOLATION:
         heights =
-            tf.math.div(tf.math.add(ySlice1, ySlice2), tf.dtypes.cast(tf.constant(2), y.type()));
+            tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type()));
         break;
       case MINORING:
         heights = tf.math.minimum(ySlice1, ySlice2);
@@ -915,6 +865,7 @@ public Operand result() {
       if (this.getLabelWeights() == null) {
         return MetricsHelper.mean(tf, byLabelAuc);
       } else {
+        // Weighted average of the label AUCs.
         return tf.math.divNoNan(
             tf.reduceSum(
                 tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())),
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java
index ee87cebfa48..bd536f16b29 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java
@@ -75,7 +75,7 @@ public class Precision extends Metric {
    * @param type the data type for the variables
    */
   public Precision(Ops tf, long seed, Class type) {
-    this(tf, null, new float[] {DEFAULT_THRESHOLD}, null, null, seed, type);
+    this(tf, null, null, null, null, seed, type);
   }
 
   /**
@@ -90,7 +90,7 @@ public Precision(Ops tf, long seed, Class type) {
    * @param type the data type for the variables
    */
   public Precision(Ops tf, String name, long seed, Class type) {
-    this(tf, name, new float[] {DEFAULT_THRESHOLD}, null, null, seed, type);
+    this(tf, name, null, null, null, seed, type);
   }
 
   /**
@@ -297,23 +297,23 @@ public List updateStateList(
       Operand labels,
       Operand predictions,
       Operand sampleWeights) {
-
+    Ops tf = getTF();
     Map> confusionMatrix = new HashMap<>();
     confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives);
     confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives);
 
-    Operand tPredictions = cast(getTF(), predictions, type);
-    Operand tLabels = cast(getTF(), labels, type);
-    Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null;
+    Operand tPredictions = cast(tf, predictions, type);
+    Operand tLabels = cast(tf, labels, type);
+    Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
 
     return new ArrayList(
         MetricsHelper.updateConfusionMatrixVariables(
-            getTF(),
+            tf,
             confusionMatrix,
             Collections.EMPTY_MAP,
             tLabels,
             tPredictions,
-            thresholds,
+            tf.constant(thresholds),
             topK,
             classId,
             tSampleWeights,
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java
index e1eebb98f77..54e9de0d9cf 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java
@@ -328,24 +328,24 @@ public List updateStateList(
       Operand labels,
       Operand predictions,
       Operand sampleWeights) {
-
+    Ops tf = getTF();
     Map> confusionMatrix = new HashMap<>();
     confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives);
     confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives);
 
-    Operand tPredictions = cast(getTF(), predictions, type);
-    Operand tLabels = cast(getTF(), labels, type);
-    Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null;
+    Operand tPredictions = cast(tf, predictions, type);
+    Operand tLabels = cast(tf, labels, type);
+    Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
 
     return MetricsHelper.updateConfusionMatrixVariables(
-        getTF(),
+        tf,
         confusionMatrix,
         Collections.EMPTY_MAP,
         tLabels,
         tPredictions,
-        this.thresholds,
-        this.topK,
-        this.classId,
+        tf.constant(thresholds),
+        topK,
+        classId,
         tSampleWeights,
         false,
         null);
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java
index c9e762d05d4..31e88b6bb31 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java
@@ -140,9 +140,10 @@ public List updateStateList(
       Operand labels,
       Operand predictions,
       Operand sampleWeights) {
-    Operand tLabels = cast(getTF(), labels, type);
-    Operand tPredictions = cast(getTF(), predictions, type);
-    Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null;
+    Ops tf = getTF();
+    Operand tLabels = cast(tf, labels, type);
+    Operand tPredictions = cast(tf, predictions, type);
+    Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
     return new ArrayList<>(
         MetricsHelper.updateConfusionMatrixVariables(
             getTF(),
@@ -150,7 +151,7 @@ public List updateStateList(
             Collections.singletonMap(confusionMatrixCond, initializer),
             tLabels,
             tPredictions,
-            thresholds,
+            tf.constant(thresholds),
             null,
             null,
             tSampleWeights,
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
index 0be0a7a572a..45a236ef814 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -15,6 +15,7 @@
 package org.tensorflow.framework.metrics.impl;
 
 import org.tensorflow.Operand;
+import org.tensorflow.Session;
 import org.tensorflow.framework.losses.impl.LossTuple;
 import org.tensorflow.framework.losses.impl.LossesHelper;
 import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException;
@@ -26,10 +27,7 @@
 import org.tensorflow.op.core.*;
 import org.tensorflow.op.math.Mean;
 import org.tensorflow.op.nn.TopK;
-import org.tensorflow.types.TBool;
-import org.tensorflow.types.TFloat64;
-import org.tensorflow.types.TInt32;
-import org.tensorflow.types.TInt64;
+import org.tensorflow.types.*;
 import org.tensorflow.types.family.TIntegral;
 import org.tensorflow.types.family.TNumber;
 
@@ -277,6 +275,7 @@ public static List assertShapes(
     return updateOperations;
   }
 
+
   /**
    * Returns an op to update the given confusion matrix variables.
    *
@@ -335,7 +334,7 @@ public static  List updateConfusionMatrixVariables(
       Map> varInitializers,
       Operand labels,
       Operand predictions,
-      float[] thresholds,
+      Operand thresholds,
       Integer topK,
       Integer classId,
       Operand sampleWeight,
@@ -349,68 +348,65 @@ public static  List updateConfusionMatrixVariables(
       return Collections.EMPTY_LIST;
     }
 
-    Operand lLabels = labels;
-    Operand lPredictions = predictions;
-    Operand lSampleWeight = sampleWeight;
+    Operand tLabels = labels;
+    Operand tPredictions = predictions;
+    Operand tSampleWeight = sampleWeight;
 
-    Operand numThresholds;
+    Operand numThresholds =
+        tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar()));
     Operand oneThresh;
     if (multiLabel) {
-      numThresholds = tf.shape.size(lLabels, tf.constant(0));
-      oneThresh = tf.math.equal(tf.constant(1), tf.constant(thresholds.length));
+      oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds));
     } else {
       // TODO handle Ragged Tensors????
       // [y_pred,
       //    y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
       //                                                   sampleWeights)
-      numThresholds = tf.constant(thresholds.length);
       oneThresh = tf.constant(true);
+      numThresholds = tf.shape.size(tf.shape(thresholds));
     }
 
     List controlOps = new ArrayList<>();
-    Operand axes = allAxes(tf, lPredictions);
+    Operand axes = allAxes(tf, tPredictions);
     controlOps.add(
         tf.withSubScope("updateConfusionMatrixVariables-1")
             .assertThat(
                 tf.reduceAll(
                     tf.math.greaterEqual(
-                        lPredictions, cast(tf, tf.constant(0), lPredictions.type())),
+                        tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
                     axes),
                 Collections.singletonList(tf.constant("predictions must be >= 0"))));
     controlOps.add(
         tf.withSubScope("updateConfusionMatrixVariables-2")
             .assertThat(
                 tf.reduceAll(
-                    tf.math.lessEqual(lPredictions, cast(tf, tf.constant(1), lPredictions.type())),
+                    tf.math.lessEqual(tPredictions, cast(tf, tf.constant(1), tPredictions.type())),
                     axes),
                 Collections.singletonList(tf.constant("predictions must be <= 1"))));
 
     LossTuple result =
-        LossesHelper.squeezeOrExpandDimensions(tf, lLabels, lPredictions, lSampleWeight);
-    lPredictions = result.getTarget();
-    lLabels = result.getLabels();
-    lSampleWeight = result.getSampleWeights();
+        LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight);
+    tPredictions = result.getTarget();
+    tLabels = result.getLabels();
+    tSampleWeight = result.getSampleWeights();
 
-    if (!lPredictions.shape().isCompatibleWith(lLabels.shape()))
+    if (!tPredictions.shape().isCompatibleWith(tLabels.shape()))
       throw new IllegalArgumentException(
           String.format(
               "Shapes %s and %s are incompatible)",
-              lPredictions.shape().toString(), lLabels.asOutput().shape().toString()));
+              tPredictions.shape().toString(), tLabels.asOutput().shape().toString()));
 
     if (topK != null) {
-      lPredictions = filterTopK(tf, lPredictions, topK);
+      tPredictions = filterTopK(tf, tPredictions, topK);
     }
 
     if (classId != null) {
-      lLabels = tf.squeeze(tf.gather(lLabels, tf.constant(new int[] {classId}), tf.constant(1)));
-      lPredictions =
-          tf.squeeze(tf.gather(lPredictions, tf.constant(new int[] {classId}), tf.constant(1)));
-      lLabels = tf.expandDims(lLabels, tf.constant(0));
-      lPredictions = tf.expandDims(lPredictions, tf.constant(0));
+      tLabels = tf.gather(tLabels, tf.constant(new int[] {classId}), tf.constant(1));
+      tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classId}), tf.constant(1));
     }
-    org.tensorflow.op.core.Shape predShape = tf.shape(lPredictions);
+    org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions);
     Operand numPredictions =
-        tf.reshape(tf.shape.size(lPredictions, tf.constant(0)), tf.constant(Shape.scalar()));
+        tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar()));
     Operand numLabels =
         tf.select(
             tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)),
@@ -424,50 +420,52 @@ lPredictions, cast(tf, tf.constant(0), lPredictions.type())),
     Operand predictionsExtraDim;
     Operand labelsExtraDim;
     if (multiLabel) {
-      predictionsExtraDim = tf.expandDims(lPredictions, tf.constant(0));
-      labelsExtraDim = tf.expandDims(cast(tf, lLabels, TBool.class), tf.constant(0));
+      predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0));
+      labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0));
     } else {
-      predictionsExtraDim = tf.reshape(lPredictions, tf.constant(Shape.of(1, -1)));
-      labelsExtraDim = tf.reshape(cast(tf, lLabels, TBool.class), tf.constant(Shape.of(1, -1)));
+      predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1)));
+      labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1)));
     }
     List> threshPretileShape;
     List> threshTiles;
     List> dataTiles;
     if (multiLabel) {
       threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1));
-
       threshTiles = Arrays.asList(tf.constant(1), numPredictions, threshLabelTile);
       dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1));
     } else {
-      threshPretileShape = Arrays.asList(numThresholds, tf.constant(-1));
+      threshPretileShape =
+          Arrays.asList(tf.reshape(numThresholds, tf.constant(Shape.scalar())), tf.constant(-1));
       Operand mul = tf.math.mul(numPredictions, numLabels);
       threshTiles = Arrays.asList(tf.constant(1), mul);
       dataTiles = Arrays.asList(numThresholds, tf.constant(1));
     }
 
     Operand thresholdsReshaped =
-        tf.reshape(
-            cast(tf, tf.constant(thresholds), predictions.type()), tf.stack(threshPretileShape));
+        tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape));
     Operand threshTilesShape = tf.stack(threshTiles);
     Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape);
-    Operand predsTiled = tf.tile(predictionsExtraDim, tf.stack(dataTiles));
+    Operand stackedTiles = tf.stack(dataTiles);
+
+    Operand predsTiled = tf.tile(predictionsExtraDim, stackedTiles);
 
     // Compare predictions and threshold.
     Operand predIsPos = tf.math.greater(predsTiled, threshTiled);
     // Tile labels by number of thresholds
     Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles));
     Operand weightsTiled;
-    if (lSampleWeight != null) {
-      lSampleWeight =
-          tf.broadcastTo(cast(tf, lSampleWeight, predictions.type()), tf.shape(lPredictions));
-      weightsTiled = tf.tile(tf.reshape(lSampleWeight, tf.stack(threshTiles)), tf.stack(dataTiles));
+    if (tSampleWeight != null) {
+      tSampleWeight =
+          tf.broadcastTo(cast(tf, tSampleWeight, predictions.type()), tf.shape(tPredictions));
+      weightsTiled = tf.tile(tf.reshape(tSampleWeight, tf.stack(threshTiles)), tf.stack(dataTiles));
     } else {
       weightsTiled = null;
     }
 
     if (labelWeights != null) {
       Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0));
-      lLabelWeights = tf.broadcastTo(cast(tf, lLabelWeights, labelWeights.type()), lPredictions);
+
+      lLabelWeights = tf.broadcastTo(lLabelWeights, tPredictions);
       Operand labelWeightsTiled =
           tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles));
       if (weightsTiled == null) {
@@ -520,6 +518,7 @@ lPredictions, cast(tf, tf.constant(0), lPredictions.type())),
     return controlOps;
   }
 
+
   /**
    * Creates an Operand that adds the values by taking the logical and of labels and predictions to
    * the specified confusion matrix variable.
@@ -700,57 +699,57 @@ public static  Operand confusionMatrix(
               predictions.shape().toString(), labels.shape().toString()));
     tf = tf.withSubScope("confusionMatrix");
     LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, predictions, labels, null);
-    Operand lPredictions = cast(tf, ops.getTarget(), TInt64.class);
-    Operand lLabels = cast(tf, ops.getLabels(), TInt64.class);
+    Operand tPredictions = cast(tf, ops.getTarget(), TInt64.class);
+    Operand tLabels = cast(tf, ops.getLabels(), TInt64.class);
 
     List labelControls = new ArrayList<>();
     List predictionControls = new ArrayList<>();
 
     labelControls.add(
         tf.assertThat(
-            tf.reduceAny(tf.math.greaterEqual(lLabels, tf.constant(0L)), allAxes(tf, lLabels)),
+            tf.reduceAny(tf.math.greaterEqual(tLabels, tf.constant(0L)), allAxes(tf, tLabels)),
             Collections.singletonList(tf.constant("`labels` contains negative values"))));
 
     predictionControls.add(
         tf.assertThat(
             tf.reduceAny(
-                tf.math.greaterEqual(lPredictions, tf.constant(0L)), allAxes(tf, lPredictions)),
+                tf.math.greaterEqual(tPredictions, tf.constant(0L)), allAxes(tf, tPredictions)),
             Collections.singletonList(tf.constant("`predictions` contains negative values"))));
     if (numClasses == null) {
       numClasses =
           tf.math.maximum(
-              tf.reduceMax(lPredictions, allAxes(tf, lPredictions)),
-              tf.reduceMax(lLabels, allAxes(tf, lLabels)));
+              tf.reduceMax(tPredictions, allAxes(tf, tPredictions)),
+              tf.reduceMax(tLabels, allAxes(tf, tLabels)));
     } else {
       labelControls.add(
           tf.assertThat(
-              tf.reduceAny(tf.math.less(lLabels, numClasses), allAxes(tf, lLabels)),
+              tf.reduceAny(tf.math.less(tLabels, numClasses), allAxes(tf, tLabels)),
               Collections.singletonList(tf.constant("``labels` out of bounds"))));
       predictionControls.add(
           tf.assertThat(
-              tf.reduceAny(tf.math.less(lPredictions, numClasses), allAxes(tf, lPredictions)),
+              tf.reduceAny(tf.math.less(tPredictions, numClasses), allAxes(tf, tPredictions)),
               Collections.singletonList(tf.constant("``predictions` out of bounds"))));
     }
 
     if (weights != null) {
-      if (!lPredictions.shape().isCompatibleWith(weights.shape())) {
+      if (!tPredictions.shape().isCompatibleWith(weights.shape())) {
         throw new IllegalArgumentException(
             String.format(
                 "Prediction shape %s is not compatible with weights shape %s",
-                lPredictions.shape().toString(), weights.shape().toString()));
+                tPredictions.shape().toString(), weights.shape().toString()));
       }
     }
 
     Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls);
-    lLabels = tfc.identity(lLabels);
+    tLabels = tfc.identity(tLabels);
 
     tfc = tf.withSubScope("confusionMatrixPredicitons").withControlDependencies(predictionControls);
-    lPredictions = tfc.identity(lPredictions);
+    tPredictions = tfc.identity(tPredictions);
 
     Operand shape = tf.stack(Arrays.asList(numClasses, numClasses));
-    Operand indices = tf.stack(Arrays.asList(lLabels, lPredictions), Stack.axis(1L));
+    Operand indices = tf.stack(Arrays.asList(tLabels, tPredictions), Stack.axis(1L));
     Operand values =
-        weights == null ? cast(tf, tf.onesLike(lPredictions), type) : cast(tf, weights, type);
+        weights == null ? cast(tf, tf.onesLike(tPredictions), type) : cast(tf, weights, type);
     SparseTensor cmSparse = new SparseTensor<>(indices, values, shape);
     Operand zeroMatrix = tf.zeros(shape, type);
 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java
index 377124333bd..84898d8a4d3 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java
@@ -143,10 +143,10 @@ public List updateStateList(
     Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
 
     Map> confusionMatrix = new HashMap<>();
-    confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.getTruePositives());
-    confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.getFalsePositives());
-    confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.getTrueNegatives());
-    confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.getFalseNegatives());
+    confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTruePositives());
+    confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, getFalsePositives());
+    confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, getTrueNegatives());
+    confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, getFalseNegatives());
 
     return MetricsHelper.updateConfusionMatrixVariables(
         tf,
@@ -154,7 +154,7 @@ public List updateStateList(
         Collections.EMPTY_MAP,
         tLabels,
         tPredictions,
-        this.getThresholds(),
+        tf.constant(thresholds),
         null,
         null,
         tSampleWeights,
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java
index 35962a568ca..148ca520d3f 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java
@@ -16,6 +16,7 @@
 
 import org.junit.jupiter.api.Test;
 import org.tensorflow.Operand;
+import org.tensorflow.framework.metrics.impl.MetricsHelper;
 import org.tensorflow.framework.utils.TestSession;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
@@ -125,7 +126,6 @@ public void testDivByZero() {
       Op update = instance.updateState(labels, predictions, null);
       session.run(update);
       Operand precision = instance.result();
-
       session.evaluate(0, precision);
     }
   }

From 60d513da232f3e208930847c7e7cb04d2ac2c06d Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Tue, 2 Mar 2021 12:01:28 -0500
Subject: [PATCH 52/97] change classId to classIndex

Added comment on Operand numThresholds reshape to scalar.

Added comment to ExtraDims
---
 .../framework/metrics/impl/MetricsHelper.java      | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
index 45a236ef814..302997eb51a 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -309,7 +309,8 @@ public static List assertShapes(
    *     topK is set)
    * @param topK Optional, indicates that the positive labels should be limited to the top k
    *     predictions, may be null.
-   * @param classId Optional, limits the prediction and labels to the specified class
+   * @param classIndex Optional, limits the prediction and labels to the specified class.
+   *                The classIndex is an integer identifying the specific classification class to which the predictions and labels are restricted.
    * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as
    *     labels, and must be broadcast to labels (i.e., all dimensions
    *     must be either 1, or the same as the corresponding labels
@@ -336,7 +337,7 @@ public static  List updateConfusionMatrixVariables(
       Operand predictions,
       Operand thresholds,
       Integer topK,
-      Integer classId,
+      Integer classIndex,
       Operand sampleWeight,
       boolean multiLabel,
       Operand labelWeights) {
@@ -352,6 +353,7 @@ public static  List updateConfusionMatrixVariables(
     Operand tPredictions = predictions;
     Operand tSampleWeight = sampleWeight;
 
+    // reshape to scalar for operations later.
     Operand numThresholds =
         tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar()));
     Operand oneThresh;
@@ -363,7 +365,6 @@ public static  List updateConfusionMatrixVariables(
       //    y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
       //                                                   sampleWeights)
       oneThresh = tf.constant(true);
-      numThresholds = tf.shape.size(tf.shape(thresholds));
     }
 
     List controlOps = new ArrayList<>();
@@ -400,9 +401,9 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
       tPredictions = filterTopK(tf, tPredictions, topK);
     }
 
-    if (classId != null) {
-      tLabels = tf.gather(tLabels, tf.constant(new int[] {classId}), tf.constant(1));
-      tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classId}), tf.constant(1));
+    if (classIndex != null) {
+      tLabels = tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(1));
+      tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1));
     }
     org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions);
     Operand numPredictions =
@@ -417,6 +418,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())),
                 tf.constant(0)));
     Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1));
 
+    // The ExtraDims are added so the operands of the tile operations later on are compatible.
     Operand predictionsExtraDim;
     Operand labelsExtraDim;
     if (multiLabel) {

From ea2e3b1f3d9bad39b64b4328370c1998c613df75 Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Tue, 2 Mar 2021 12:02:32 -0500
Subject: [PATCH 53/97] fix spurious "this.".

---
 .../org/tensorflow/framework/metrics/AUC.java | 78 +++++++++----------
 1 file changed, 39 insertions(+), 39 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
index 8a31dfd3fce..1269f3453f3 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
@@ -613,7 +613,7 @@ private Map> build(Shape shape) {
     }
     Ops tf = getTF();
 
-    if (this.isMultiLabel()) {
+    if (isMultiLabel()) {
       if (shape == null) {
         throw new IllegalArgumentException("For multiLabel, a shape must be provided");
       }
@@ -622,14 +622,14 @@ private Map> build(Shape shape) {
             String.format(
                 "labels must have rank=2 when multiLabel is true. Found rank %d.",
                 shape.numDimensions()));
-      this.numLabels = (int) shape.size(1);
-      variableShape = Shape.of(this.numThresholds, this.numLabels);
+      numLabels = (int) shape.size(1);
+      variableShape = Shape.of(numThresholds, numLabels);
     } else {
-      variableShape = Shape.of(this.numThresholds);
+      variableShape = Shape.of(numThresholds);
     }
 
     // Create metric variables
-    Zeros zeros = new Zeros<>(getTF());
+    Zeros zeros = new Zeros<>(tf);
     Operand zero = zeros.call(tf.constant(variableShape), type);
     if (truePositives == null) {
       truePositives = tf.withName(getTruePositivesName()).variable(zero);
@@ -651,7 +651,7 @@ private Map> build(Shape shape) {
       initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, tf.assign(falseNegatives, zero));
     }
 
-    this.initialized = true;
+    initialized = true;
     return initializers;
   }
 
@@ -668,39 +668,39 @@ public List updateStateList(
     Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null;
     List updateOperations = new ArrayList<>();
     Map> varInitializers = Collections.EMPTY_MAP;
-    if (!this.initialized) {
+    if (!initialized) {
       varInitializers = build(tPredictions.shape());
     }
-    if (this.isMultiLabel() || this.getLabelWeights() != null) {
+    if (isMultiLabel() || getLabelWeights() != null) {
       // labels should have shape (number of examples, number of labels).
       List> symbols = new ArrayList<>();
       symbols.add(new SymbolicShape<>(tLabels, "N", "L"));
-      if (this.isMultiLabel()) {
+      if (isMultiLabel()) {
         // TP, TN, FP, and FN should all have shape
         //(number of thresholds, number of labels).
-        symbols.add(new SymbolicShape<>(this.truePositives, "T", "L"));
-        symbols.add(new SymbolicShape<>(this.falsePositives, "T", "L"));
-        symbols.add(new SymbolicShape<>(this.trueNegatives, "T", "L"));
-        symbols.add(new SymbolicShape<>(this.falseNegatives, "T", "L"));
+        symbols.add(new SymbolicShape<>(truePositives, "T", "L"));
+        symbols.add(new SymbolicShape<>(falsePositives, "T", "L"));
+        symbols.add(new SymbolicShape<>(trueNegatives, "T", "L"));
+        symbols.add(new SymbolicShape<>(falseNegatives, "T", "L"));
       }
-      if (this.getLabelWeights() != null) {
-        symbols.add(new SymbolicShape<>(this.getLabelWeights(), "L", ""));
+      if (getLabelWeights() != null) {
+        symbols.add(new SymbolicShape<>(getLabelWeights(), "L", ""));
       }
       updateOperations.addAll(
           MetricsHelper.assertShapes(tf, symbols, "Number of labels is not consistent."));
     }
 
     Map> confusionMatrix = new HashMap<>();
-    confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives);
-    confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, this.falsePositives);
-    confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, this.trueNegatives);
-    confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives);
+    confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives);
+    confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives);
+    confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, trueNegatives);
+    confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, falseNegatives);
 
     // Only forward labelWeights to update_confusion_matrix_variables when
     // multiLabel is false. Otherwise the averaging of individual label AUCs is
     // handled in AUC.result
-    if (this.isMultiLabel()) {
-      this.labelWeights = null;
+    if (isMultiLabel()) {
+      labelWeights = null;
     }
     updateOperations.addAll(
         MetricsHelper.updateConfusionMatrixVariables(
@@ -730,7 +730,7 @@ private Operand interpolatePRAuc() {
         tf.slice(
             truePositives,
             tf.constant(new int[] {0}),
-            tf.constant(new int[] {this.getNumThresholds() - 1}));
+            tf.constant(new int[] {getNumThresholds() - 1}));
     // truePositives[1:]
     Operand tp1 =
         tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
@@ -744,7 +744,7 @@ private Operand interpolatePRAuc() {
             tf.slice(
                 p,
                 tf.constant(new int[] {0}),
-                tf.constant(new int[] {this.getNumThresholds() - 1})),
+                tf.constant(new int[] {getNumThresholds() - 1})),
             tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})));
 
     Operand precisionSlope =
@@ -764,7 +764,7 @@ private Operand interpolatePRAuc() {
                     tf.slice(
                         p,
                         tf.constant(new int[] {0}),
-                        tf.constant(new int[] {this.getNumThresholds() - 1})),
+                        tf.constant(new int[] {getNumThresholds() - 1})),
                     tf.dtypes.cast(tf.constant(0), p.type())),
                 tf.math.greater(
                     tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
@@ -773,7 +773,7 @@ private Operand interpolatePRAuc() {
                 tf.slice(
                     p,
                     tf.constant(new int[] {0}),
-                    tf.constant(new int[] {this.getNumThresholds() - 1})),
+                    tf.constant(new int[] {getNumThresholds() - 1})),
                 tf.math.maximum(
                     tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})),
                     tf.dtypes.cast(tf.constant(0), p.type()))),
@@ -790,15 +790,15 @@ private Operand interpolatePRAuc() {
         tf.math.divNoNan(
             aucTotalPos,
             tf.math.maximum(
-                tf.math.add(tp1, fn1), tf.dtypes.cast(tf.constant(0), this.truePositives.type())));
+                tf.math.add(tp1, fn1), tf.dtypes.cast(tf.constant(0), truePositives.type())));
 
-    if (this.isMultiLabel()) {
+    if (isMultiLabel()) {
       Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0));
-      if (this.getLabelWeights() == null) {
+      if (getLabelWeights() == null) {
         return MetricsHelper.mean(tf, byLabelAuc);
       } else {
         return tf.math.divNoNan(
-            tf.reduceSum(tf.math.mul(byLabelAuc, this.getLabelWeights()), allAxes(tf, byLabelAuc)),
+            tf.reduceSum(tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, byLabelAuc)),
             tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights())));
       }
     } else {
@@ -810,17 +810,17 @@ private Operand interpolatePRAuc() {
   @Override
   public Operand result() {
 
-    if (this.getCurve() == AUCCurve.PR
-        && this.getSummationMethod() == AUCSummationMethod.INTERPOLATION) {
+    if (getCurve() == AUCCurve.PR
+        && getSummationMethod() == AUCSummationMethod.INTERPOLATION) {
       // This use case is different and is handled separately.
-      return this.interpolatePRAuc();
+      return interpolatePRAuc();
     }
     Ops tf = getTF();
     Operand x;
     Operand y;
     Operand recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives));
 
-    if (this.getCurve() == AUCCurve.ROC) {
+    if (getCurve() == AUCCurve.ROC) {
       x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives));
       y = recall;
     } else { // AUCCurve.PR
@@ -832,12 +832,12 @@ public Operand result() {
     // y[:self.numThresholds - 1]
     Operand ySlice1 =
         tf.slice(
-            y, tf.constant(new int[] {0}), tf.constant(new int[] {this.getNumThresholds() - 1}));
+            y, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1}));
     // y[1:]
     Operand ySlice2 = tf.slice(y, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
 
     Operand heights = null;
-    switch (this.getSummationMethod()) {
+    switch (getSummationMethod()) {
       case INTERPOLATION:
         heights =
             tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type()));
@@ -850,19 +850,19 @@ public Operand result() {
         break;
     }
 
-    if (this.isMultiLabel()) {
+    if (isMultiLabel()) {
       Operand riemannTerms =
           tf.math.mul(
               tf.math.sub(
                   tf.slice(
                       x,
                       tf.constant(new int[] {0}),
-                      tf.constant(new int[] {this.getNumThresholds() - 1})),
+                      tf.constant(new int[] {getNumThresholds() - 1})),
                   tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))),
               heights);
       Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0));
 
-      if (this.getLabelWeights() == null) {
+      if (getLabelWeights() == null) {
         return MetricsHelper.mean(tf, byLabelAuc);
       } else {
         //Weighted average of the label AUCs.
@@ -875,7 +875,7 @@ public Operand result() {
     } else {
       Operand slice1 =
           tf.slice(
-              x, tf.constant(new int[] {0}), tf.constant(new int[] {this.getNumThresholds() - 1}));
+              x, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1}));
       Operand slice2 = tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1}));
       Operand sub = tf.math.sub(slice1, slice2);
       Operand operand = tf.math.mul(sub, heights);

From ca9e3959650c87d65e459c66bfd3460bb75df80b Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Wed, 3 Mar 2021 12:31:11 -0500
Subject: [PATCH 54/97] Remove references to keras in javadoc.

---
 .../java/org/tensorflow/framework/metrics/AUC.java    |  2 +-
 .../framework/metrics/SparseCategoricalAccuracy.java  | 11 +----------
 2 files changed, 2 insertions(+), 11 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
index 1269f3453f3..5ac07e98451 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
@@ -58,7 +58,7 @@
  * 

Usage:
* *

- * AUC m = new  tf.keras.metrics.AUC( tf, 3);
+ * AUC m = new org.tensorflow.framework.metrics.AUC( tf, 3);
  * m.updateState( tf.constant(new float[] {0, 0, 1,1}),
  *          tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
  *
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java
index 7034861d8d2..0d18c1e2dcb 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java
@@ -44,7 +44,7 @@
  * 

Usage: * *

- * SparseCategoricalAccuracy m = new tf.keras.metrics.SparseCategoricalAccuracy();
+ * SparseCategoricalAccuracy m = new org.tensorflow.framework.metrics.SparseCategoricalAccuracy();
  * m.update_state(tf.constant(new float[][] {{2}, {1}},
  *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, [{0.05f, 0.95f, 0f}});
  * Operand<TFloat32> result = m.result();
@@ -63,15 +63,6 @@
  * 0.3
  * 
* - *

Usage with tf.keras API: - * - *

- * Model model = new tf.keras. models.Model(inputs, outputs);
- * model.compile(
- *     "sgd",
- *     loss="mse",
- *     metrics=["sparse_categorical_accuracy"]);
- * 
* * @param The data type for the metric result */ From 817430f0b5eba0b8a391f8a729b1047454fc66a4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Mar 2021 12:51:53 -0500 Subject: [PATCH 55/97] Fix javadoc --- .../java/org/tensorflow/framework/constraints/Constraint.java | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index d3094b5e9e9..306361959bf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -42,6 +42,7 @@ public Constraint(Ops tf) { * * @param weights the weights * @return the constrained weights + * @param the data type for weights and results. */ public abstract Operand call(Operand weights); From b440c634197931baf5f78b277d6560e8bea08810 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Mar 2021 13:27:32 -0500 Subject: [PATCH 56/97] Reformat code and fix labelWeights argument in call to updateConfusionMatrixVariables --- .../org/tensorflow/framework/metrics/AUC.java | 61 +++++++++++-------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 5ac07e98451..cd83bbeb26d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -209,7 +209,7 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { this( tf, null, - DEFAULT_NUM_THRESHOLDS, + DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, thresholds, @@ -264,7 +264,7 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { this( tf, name, - 
DEFAULT_NUM_THRESHOLDS, + DEFAULT_NUM_THRESHOLDS, AUCCurve.ROC, AUCSummationMethod.INTERPOLATION, thresholds, @@ -322,7 +322,7 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C this( tf, name, - DEFAULT_NUM_THRESHOLDS, + DEFAULT_NUM_THRESHOLDS, curve, AUCSummationMethod.INTERPOLATION, thresholds, @@ -378,7 +378,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) this( tf, null, - DEFAULT_NUM_THRESHOLDS, + DEFAULT_NUM_THRESHOLDS, curve, AUCSummationMethod.INTERPOLATION, thresholds, @@ -435,7 +435,17 @@ public AUC( AUCSummationMethod summationMethod, long seed, Class type) { - this(tf, null, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type); + this( + tf, + null, + DEFAULT_NUM_THRESHOLDS, + curve, + summationMethod, + thresholds, + false, + null, + seed, + type); } /** @@ -487,7 +497,17 @@ public AUC( AUCSummationMethod summationMethod, long seed, Class type) { - this(tf, name, DEFAULT_NUM_THRESHOLDS, curve, summationMethod, thresholds, false, null, seed, type); + this( + tf, + name, + DEFAULT_NUM_THRESHOLDS, + curve, + summationMethod, + thresholds, + false, + null, + seed, + type); } /** @@ -677,7 +697,7 @@ public List updateStateList( symbols.add(new SymbolicShape<>(tLabels, "N", "L")); if (isMultiLabel()) { // TP, TN, FP, and FN should all have shape - //(number of thresholds, number of labels). + // (number of thresholds, number of labels). symbols.add(new SymbolicShape<>(truePositives, "T", "L")); symbols.add(new SymbolicShape<>(falsePositives, "T", "L")); symbols.add(new SymbolicShape<>(trueNegatives, "T", "L")); @@ -699,9 +719,6 @@ public List updateStateList( // Only forward labelWeights to update_confusion_matrix_variables when // multiLabel is false. 
Otherwise the averaging of individual label AUCs is // handled in AUC.result - if (isMultiLabel()) { - labelWeights = null; - } updateOperations.addAll( MetricsHelper.updateConfusionMatrixVariables( tf, @@ -714,7 +731,7 @@ public List updateStateList( null, tSampleWeights, isMultiLabel(), - getLabelWeights())); + isMultiLabel() ? null : getLabelWeights())); return updateOperations; } @@ -742,9 +759,7 @@ private Operand interpolatePRAuc() { Operand dP = tf.math.sub( tf.slice( - p, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})), + p, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})), tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))); Operand precisionSlope = @@ -771,9 +786,7 @@ private Operand interpolatePRAuc() { tf.dtypes.cast(tf.constant(0), p.type()))), tf.math.divNoNan( tf.slice( - p, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})), + p, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})), tf.math.maximum( tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})), tf.dtypes.cast(tf.constant(0), p.type()))), @@ -810,8 +823,7 @@ private Operand interpolatePRAuc() { @Override public Operand result() { - if (getCurve() == AUCCurve.PR - && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { + if (getCurve() == AUCCurve.PR && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { // This use case is different and is handled separately. return interpolatePRAuc(); } @@ -831,16 +843,14 @@ && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { // Find the rectangle heights based on `summationMethod`. 
// y[:self.numThresholds - 1] Operand ySlice1 = - tf.slice( - y, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); + tf.slice(y, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); // y[1:] Operand ySlice2 = tf.slice(y, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); Operand heights = null; switch (getSummationMethod()) { case INTERPOLATION: - heights = - tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type())); + heights = tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type())); break; case MINORING: heights = tf.math.minimum(ySlice1, ySlice2); @@ -865,7 +875,7 @@ && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { if (getLabelWeights() == null) { return MetricsHelper.mean(tf, byLabelAuc); } else { - //Weighted average of the label AUCs. + // Weighted average of the label AUCs. return tf.math.divNoNan( tf.reduceSum( tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())), @@ -874,8 +884,7 @@ && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { } else { Operand slice1 = - tf.slice( - x, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); + tf.slice(x, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); Operand slice2 = tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); Operand sub = tf.math.sub(slice1, slice2); Operand operand = tf.math.mul(sub, heights); From 021df65c251eae6576f21c390a06835292ca348b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Mar 2021 13:28:50 -0500 Subject: [PATCH 57/97] Reformat code add code comments and change update_xx (update_fn) to updateXX (updateFN) to eliminate snake case. 
--- .../framework/metrics/impl/MetricsHelper.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 302997eb51a..3d4a2c8dc4f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -413,6 +413,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), tf.constant(1), tf.reduceProd( + // take all but the first dimension tf.shape.takeLast( predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); @@ -479,21 +480,21 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Map loopVars = new HashMap<>(); loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, new Operand[] {labelIsPos, predIsPos}); - Variable update_tn = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); - Variable update_fp = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); - Variable update_fn = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); + Variable updateTN = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); + Variable updateFP = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); + Variable updateFN = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); Operand predIsNeg = null; Operand labelIsNeg; - if (update_fn != null || update_tn != null) { + if (updateFN != null || updateTN != null) { predIsNeg = tf.math.logicalNot(predIsPos); loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, new Operand[] {labelIsPos, predIsNeg}); } - if (update_fp != null || update_tn != null) { + if (updateFP != null || updateTN != null) { labelIsNeg = tf.math.logicalNot(labelIsPos); 
loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, new Operand[] {labelIsNeg, predIsPos}); - if (update_tn != null) { + if (updateTN != null) { loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, new Operand[] {labelIsNeg, predIsNeg}); } } From 4df4a80306815e2cee04532267f923a5d828bd2c Mon Sep 17 00:00:00 2001 From: deansher Date: Wed, 3 Mar 2021 08:39:28 -0500 Subject: [PATCH 58/97] Added javadocs and internal docs to AUC.java and MetricsHelper.java --- .../org/tensorflow/framework/metrics/AUC.java | 65 ++++++++++++++++--- .../framework/metrics/impl/MetricsHelper.java | 19 ++++-- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index cd83bbeb26d..cae67dbd4f0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -106,12 +106,46 @@ public class AUC extends Metric { private final String falseNegativesName; private final Map> initializers = new HashMap<>(); private final Class type; + + /** + * The size of the label dimension. + */ private Integer numLabels; + private Operand labelWeights; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ private Variable truePositives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ private Variable falsePositives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. 
+ * + * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ private Variable trueNegatives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ private Variable falseNegatives; + private boolean initialized; /** @@ -515,22 +549,24 @@ public AUC( * * @param tf The TensorFlow Ops * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. - * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values - * must be > 1. + * @param numThresholds the number of thresholds to use when discretizing the roc curve. + * This includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link * AUCCurve#PR} for the Precision-Recall-curve. * @param summationMethod Specifies the Riemann summation method used * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, - * the numThresholds parameter is ignored. Values should be in [0, 1]. + * the numThresholds parameter is ignored. Values should be in [0, 1]. This method + * automatically brackets the provided thresholds with a (-{@link #EPSILON}) + * below and a (1 + {@link #EPSILON}) above. * @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein * AUC is computed separately for each label and then averaged across labels, or (when false) * if the data should be flattened into a single label before AUC computation. In the latter * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an * individual data point. Should be set to false for multi-class data. 
* @param labelWeights non-negative weights used to compute AUCs for multilabel data. When - * multi_label is True, the weights are applied to the individual label AUCs when they are - * averaged to produce the multi-label AUC. When it's false, they are used to weight the - * individual label predictions in computing the confusion matrix on the flattened data. + * multiLabel is true, the weights are applied to the individual label AUCs when + * they are averaged to produce the multi-label AUC. When it's false, they are used to weight + * the individual label predictions in computing the confusion matrix on the flattened data. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. @@ -595,7 +631,7 @@ public AUC( System.arraycopy(thresholds, 0, this.thresholds, 1, thresholds.length); this.thresholds[this.numThresholds - 1] = 1 + EPSILON; - // # Handle multilabel arguments. + // Handle multilabel arguments. if (labelWeights != null) { // assert that labelWeights are non-negative. @@ -675,7 +711,20 @@ private Map> build(Shape shape) { return initializers; } - /** {@inheritDoc} */ + /** + * Creates a List of Operations to update the metric state based on labels and predictions. + * + * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more + * class dimensions, and L1 is a potential extra dimension of size 1 that + * would be squeezed. Will be cast to T. If + * {@link #multiLabel} or if {@link #labelWeights} != null, + * then Cx must be a single dimension. + * @param predictions the predictions shape (N, Cx, P1?). Will be cast to T. + * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to + * T. 
+ * + * @return a List of Operations to update the metric state + */ @Override @SuppressWarnings("unchecked") public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 3d4a2c8dc4f..f38d0896a5f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -299,18 +299,25 @@ public static List assertShapes( * * @param tf the TensorFlow Ops * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. + * corresponding variables to update as values. If multiLabel is + * false then all shapes are (T), where T is the number of thresholds. If + * multiLabel is true then all shapes are (T, C0), where C0 is the number + * of classes. * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to initializer the corresponding variables from * variablesToUpdate. * @param labels the labels, will be cast to {@link TBool} - * @param predictions the predictions whose values are in the range [0, 1]. + * shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more + * class dimensions, and L1 is a potential extra dimension of size 1 that + * would be squeezed. If multiLabel or if + * labelWeights != null, then Cx must be a single dimension. + * @param predictions the predictions shape (N, Cx, P1?). * @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when * topK is set) - * @param topK Optional, indicates that the positive labels should be limited to the top k - * predictions, may be null. 
+ * @param topK Optional, used only if multiLabel, indicates that only the top k + * predictions should be considered. May be null. * @param classIndex Optional, limits the prediction and labels to the specified class. - * The classIndex is and integer representing a specific classification class's input data.. + * The classIndex is an integer index into the first dimension of Cx. * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as * labels, and must be broadcast to labels (i.e., all dimensions * must be either 1, or the same as the corresponding labels @@ -356,6 +363,8 @@ public static List updateConfusionMatrixVariables( // reshape to scalar for operations later. Operand numThresholds = tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); + + // true if we will process thresholds as one-dimensional (possibly because we flatten them) Operand oneThresh; if (multiLabel) { oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); From 99df6b4fee0ffb45d3cec0adc96550682bf83b0c Mon Sep 17 00:00:00 2001 From: deansher Date: Fri, 5 Mar 2021 07:52:07 -0500 Subject: [PATCH 59/97] Added internal docs to MetricsHelper.java --- .../framework/metrics/impl/MetricsHelper.java | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index f38d0896a5f..b57ab821b4e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -325,7 +325,7 @@ public static List assertShapes( * @param multiLabel indicates whether multidimensional prediction/labels should be treated as * multilabel responses, or flattened into a single label. 
When true, the values of * variablesToUpdate must have a second dimension equal to the number of labels and - * predictions, and those tensors must not be RaggedTensors. + * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. @@ -429,8 +429,14 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); // The ExtraDims are added so the operands of the tile operations later on are compatible. + + // if multilabel, then shape (1, N, D0) + // else shape (1, ND), + // where Dx == Cx except that D0 == 1 if classIndex != null + // ND is the product of N and all Dx Operand predictionsExtraDim; Operand labelsExtraDim; + if (multiLabel) { predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0)); labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0)); @@ -438,9 +444,22 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1))); labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1))); } + + // the shape of each thresholds tile + // if multilabel, then [T, 1, -1] + // else [T, 1], where T is numThresholds List> threshPretileShape; + + // the tiling multiples for thresholds + // if multilabel, then [1, N, threshLabelTile] + // else [1, ND], where ND is the product of N and all Dx List> threshTiles; + + // the tiling multiples for predictionsExtraDim + // If multilabel, then [T, 1, 1] + // else [T, 1] List> dataTiles; + if (multiLabel) { threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); threshTiles = 
Arrays.asList(tf.constant(1), numPredictions, threshLabelTile); @@ -456,9 +475,15 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand thresholdsReshaped = tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); Operand threshTilesShape = tf.stack(threshTiles); + + // if multilabel, then shape (T, N, threshLabelTile) + // else shape (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); + Operand stackedTiles = tf.stack(dataTiles); + // if multilabel, then shape (T, N, D0) + // else shape (T, ND) Operand predsTiled = tf.tile(predictionsExtraDim, stackedTiles); // Compare predictions and threshold. From 0710ffe472b8115e1f95a369c42344a5abbe6dd6 Mon Sep 17 00:00:00 2001 From: deansher Date: Sat, 6 Mar 2021 11:47:26 -0500 Subject: [PATCH 60/97] Improved internal docs in MetricsHelper.java --- .../framework/metrics/impl/MetricsHelper.java | 37 ++++++++++++++----- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index b57ab821b4e..b4834e28764 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -360,6 +360,7 @@ public static List updateConfusionMatrixVariables( Operand tPredictions = predictions; Operand tSampleWeight = sampleWeight; + // size of the 0th dimension of thresholds // reshape to scalar for operations later. 
Operand numThresholds = tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); @@ -415,8 +416,12 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1)); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); + + // number of examples Operand numPredictions = tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); + + // number of labels (or predictions) per example Operand numLabels = tf.select( tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), @@ -426,14 +431,24 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.shape.takeLast( predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); + + // If we will treat thresholds as one-dimensional (always true as of this writing), + // then threshLabelTile is the number of labels (or predictions) per sample. Else it is 1. Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); - // The ExtraDims are added so the operands of the tile operations later on are compatible. + ///////// + // Tile data for threshold comparisons, which is a cross product of thresholds and + // predictions/labels. + // + // In the multilabel case, we want a data shape of (T, N, D0). + // else (T, ND). + // where T is numThresholds + // Dx == Cx except that D0 == 1 if classIndex != null + // ND is the product of N and all Dx. + // In these comments, we refer to all indices beyond the threshold index as a "data position". 
// if multilabel, then shape (1, N, D0) // else shape (1, ND), - // where Dx == Cx except that D0 == 1 if classIndex != null - // ND is the product of N and all Dx Operand predictionsExtraDim; Operand labelsExtraDim; @@ -447,17 +462,19 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), // the shape of each thresholds tile // if multilabel, then [T, 1, -1] - // else [T, 1], where T is numThresholds + // else [T, -1] List> threshPretileShape; // the tiling multiples for thresholds - // if multilabel, then [1, N, threshLabelTile] - // else [1, ND], where ND is the product of N and all Dx + // We want to repeat the thresholds for each data position. + // if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels) + // else [1, ND] List> threshTiles; - // the tiling multiples for predictionsExtraDim + // tiling multiples for predictionsExtraDim and labelsExtraDim + // We want to repeat the predictions and labels for each threshold. // If multilabel, then [T, 1, 1] - // else [T, 1] + // else [T, 1] List> dataTiles; if (multiLabel) { @@ -477,13 +494,13 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand threshTilesShape = tf.stack(threshTiles); // if multilabel, then shape (T, N, threshLabelTile) - // else shape (T, ND) + // else (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); Operand stackedTiles = tf.stack(dataTiles); // if multilabel, then shape (T, N, D0) - // else shape (T, ND) + // else (T, ND) Operand predsTiled = tf.tile(predictionsExtraDim, stackedTiles); // Compare predictions and threshold. From 7e61ba288aa49c768ba5569ad62ea60705052ff6 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 10 Mar 2021 10:34:24 -0500 Subject: [PATCH 61/97] Cleanup of updateConfusionMatrixVariables with variable name changes and reuse of previously declared/assigned variables. 
--- .../framework/metrics/impl/MetricsHelper.java | 83 ++++--------------- 1 file changed, 16 insertions(+), 67 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index b4834e28764..3fa602b0a74 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -255,7 +255,7 @@ public static List assertShapes( s -> { Long size = dict.get(s); if (size == null) { - size = symbol.getOperand().asOutput().shape().size((int) ll.get()); + size = symbol.getOperand().shape().size((int) ll.get()); dict.put(s, size); } Op assertion = @@ -299,25 +299,18 @@ public static List assertShapes( * * @param tf the TensorFlow Ops * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If multiLabel is - * false then all shapes are (T), where T is the number of thresholds. If - * multiLabel is true then all shapes are (T, C0), where C0 is the number - * of classes. + * corresponding variables to update as values. * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to initializer the corresponding variables from * variablesToUpdate. * @param labels the labels, will be cast to {@link TBool} - * shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more - * class dimensions, and L1 is a potential extra dimension of size 1 that - * would be squeezed. If multiLabel or if - * labelWeights != null, then Cx must be a single dimension. - * @param predictions the predictions shape (N, Cx, P1?). + * @param predictions the predictions whose values are in the range [0, 1]. 
* @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when * topK is set) - * @param topK Optional, used only if multiLabel, indicates that only the top k - * predictions should be considered. May be null. + * @param topK Optional, indicates that the positive labels should be limited to the top k + * predictions, may be null. * @param classIndex Optional, limits the prediction and labels to the specified class. - * The classIndex is an integer index into the first dimension of Cx. + * The classIndex is and integer representing a specific classification class's input data.. * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as * labels, and must be broadcast to labels (i.e., all dimensions * must be either 1, or the same as the corresponding labels @@ -325,7 +318,7 @@ public static List assertShapes( * @param multiLabel indicates whether multidimensional prediction/labels should be treated as * multilabel responses, or flattened into a single label. When true, the values of * variablesToUpdate must have a second dimension equal to the number of labels and - * predictions per example, and those tensors must not be RaggedTensors. + * predictions, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. @@ -360,12 +353,9 @@ public static List updateConfusionMatrixVariables( Operand tPredictions = predictions; Operand tSampleWeight = sampleWeight; - // size of the 0th dimension of thresholds // reshape to scalar for operations later. 
Operand numThresholds = tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); - - // true if we will process thresholds as one-dimensional (possibly because we flatten them) Operand oneThresh; if (multiLabel) { oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); @@ -405,7 +395,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), throw new IllegalArgumentException( String.format( "Shapes %s and %s are incompatible)", - tPredictions.shape().toString(), tLabels.asOutput().shape().toString())); + tPredictions.shape().toString(), tLabels.shape().toString())); if (topK != null) { tPredictions = filterTopK(tf, tPredictions, topK); @@ -416,12 +406,8 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1)); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); - - // number of examples - Operand numPredictions = + Operand numExamples = tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); - - // number of labels (or predictions) per example Operand numLabels = tf.select( tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), @@ -431,27 +417,11 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.shape.takeLast( predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); - - // If we will treat thresholds as one-dimensional (always true as of this writing), - // then threshLabelTile is the number of labels (or predictions) per sample. Else it is 1. Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); - ///////// - // Tile data for threshold comparisons, which is a cross product of thresholds and - // predictions/labels. - // - // In the multilabel case, we want a data shape of (T, N, D0). - // else (T, ND). 
- // where T is numThresholds - // Dx == Cx except that D0 == 1 if classIndex != null - // ND is the product of N and all Dx. - // In these comments, we refer to all indices beyond the threshold index as a "data position". - - // if multilabel, then shape (1, N, D0) - // else shape (1, ND), + // The ExtraDims are added so the operands of the tile operations later on are compatible. Operand predictionsExtraDim; Operand labelsExtraDim; - if (multiLabel) { predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0)); labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0)); @@ -459,32 +429,17 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1))); labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1))); } - - // the shape of each thresholds tile - // if multilabel, then [T, 1, -1] - // else [T, -1] List> threshPretileShape; - - // the tiling multiples for thresholds - // We want to repeat the thresholds for each data position. - // if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels) - // else [1, ND] List> threshTiles; - - // tiling multiples for predictionsExtraDim and labelsExtraDim - // We want to repeat the predictions and labels for each threshold. 
- // If multilabel, then [T, 1, 1] - // else [T, 1] List> dataTiles; - if (multiLabel) { threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); - threshTiles = Arrays.asList(tf.constant(1), numPredictions, threshLabelTile); + threshTiles = Arrays.asList(tf.constant(1), numExamples, threshLabelTile); dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1)); } else { threshPretileShape = Arrays.asList(tf.reshape(numThresholds, tf.constant(Shape.scalar())), tf.constant(-1)); - Operand mul = tf.math.mul(numPredictions, numLabels); + Operand mul = tf.math.mul(numExamples, numLabels); threshTiles = Arrays.asList(tf.constant(1), mul); dataTiles = Arrays.asList(numThresholds, tf.constant(1)); } @@ -492,16 +447,10 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand thresholdsReshaped = tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); Operand threshTilesShape = tf.stack(threshTiles); - - // if multilabel, then shape (T, N, threshLabelTile) - // else (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); + Operand dataTilesShape = tf.stack(dataTiles); - Operand stackedTiles = tf.stack(dataTiles); - - // if multilabel, then shape (T, N, D0) - // else (T, ND) - Operand predsTiled = tf.tile(predictionsExtraDim, stackedTiles); + Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); // Compare predictions and threshold. 
Operand predIsPos = tf.math.greater(predsTiled, threshTiled); @@ -510,8 +459,8 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand weightsTiled; if (tSampleWeight != null) { tSampleWeight = - tf.broadcastTo(cast(tf, tSampleWeight, predictions.type()), tf.shape(tPredictions)); - weightsTiled = tf.tile(tf.reshape(tSampleWeight, tf.stack(threshTiles)), tf.stack(dataTiles)); + tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); + weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); } else { weightsTiled = null; } From efd7d43a4c5c65d32fca67372833acb961e5a496 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 10 Mar 2021 13:49:30 -0500 Subject: [PATCH 62/97] Reformat code --- .../framework/metrics/SparseCategoricalAccuracy.java | 1 - .../framework/metrics/impl/MetricsHelper.java | 12 ++++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 0d18c1e2dcb..7bfa7fd6ee9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -63,7 +63,6 @@ * 0.3 *
* - * * @param The data type for the metric result */ public class SparseCategoricalAccuracy extends MeanMetricWrapper diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 3fa602b0a74..5f4735c818c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; -import org.tensorflow.Session; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; @@ -275,7 +274,6 @@ public static List assertShapes( return updateOperations; } - /** * Returns an op to update the given confusion matrix variables. * @@ -309,8 +307,8 @@ public static List assertShapes( * topK is set) * @param topK Optional, indicates that the positive labels should be limited to the top k * predictions, may be null. - * @param classIndex Optional, limits the prediction and labels to the specified class. - * The classIndex is and integer representing a specific classification class's input data.. + * @param classIndex Optional, limits the prediction and labels to the specified class. The + * classIndex is and integer representing a specific classification class's input data.. 
* @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as * labels, and must be broadcast to labels (i.e., all dimensions * must be either 1, or the same as the corresponding labels @@ -413,7 +411,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), tf.constant(1), tf.reduceProd( - // take all but the first dimension + // take all but the first dimension tf.shape.takeLast( predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); @@ -458,8 +456,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles)); Operand weightsTiled; if (tSampleWeight != null) { - tSampleWeight = - tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); + tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); } else { weightsTiled = null; @@ -521,7 +518,6 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), return controlOps; } - /** * Creates an Operand that adds the values by taking the logical and of labels and predictions to * the specified confusion matrix variable. 
From 5e907df4ef521218c298015d056242044946daa2 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 15:48:53 -0500 Subject: [PATCH 63/97] Fix JavaDoc for enumerations --- .../framework/metrics/AUCSummationMethod.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java index 60687dd9005..3887f687eea 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java @@ -15,18 +15,18 @@ package org.tensorflow.framework.metrics; /** - * Specifies the Riemann summation method used. {@link #INTERPOLATION} (default) applies mid-point - * summation scheme for ROC. For PR-AUC, interpolates (true/false) positives but not the ratio that - * is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left - * summation for increasing intervals and right summation for decreasing intervals; {@link - * #MAJORING} does the opposite. + * Specifies the Riemann summation method used. * * @see Davis & Goadrich. 2006 * @see Riemann summation method */ public enum AUCSummationMethod { + /** Apply mid-point summation scheme for {@link AUCCurve#ROC}. For {@link AUCCurve#PR}, interpolates (true/false) positives but not the ratio that + * is precision */ INTERPOLATION, + /** Apply right summation for increasing intervals and left summation for decreasing intervals */ MAJORING, + /** Apply left summation for increasing intervals and right summation for decreasing intervals */ MINORING; /** From 8856c9cdb1e802c53646362c84153d4e56338fbd Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 15:49:20 -0500 Subject: [PATCH 64/97] Fix JavaDoc to emphasize that this does not inherit from Tensor. 
--- .../org/tensorflow/framework/utils/SparseTensor.java | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java
index 9dee070eea9..81d658ff3a5 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java
@@ -15,14 +15,19 @@
 package org.tensorflow.framework.utils;
 
 import org.tensorflow.Operand;
+import org.tensorflow.Tensor;
+import org.tensorflow.op.SparseOps;
 import org.tensorflow.types.TInt64;
 import org.tensorflow.types.family.TType;
 
 /**
- * This is a helper class that represents a sparse tensor who's attributes may be passed to
- * {@link org.tensorflow.op.Ops#sparse} methods.
+ * This is a helper class that represents a sparse tensor whose attributes may be passed to {@link
+ * SparseOps} methods.
 
- * @param the type of the SparseTensor
+ *

This class does not inherit from {@link Tensor}, but is merely a place to accumulate the + * properties that are needed for the {@link SparseOps} methods. + * + * @param the type of the SparseTensor's values. */ public class SparseTensor { private final Operand indices; From b533b2ef16bc10688b9b1161a935cb7de27dd898 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 15:50:06 -0500 Subject: [PATCH 65/97] Fix 'import *' --- .../framework/metrics/impl/MetricsHelper.java | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 5f4735c818c..81459515657 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -22,15 +22,31 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; + +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Rank; import org.tensorflow.op.core.Stack; -import org.tensorflow.op.core.*; +import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; import org.tensorflow.op.nn.TopK; -import org.tensorflow.types.*; + + +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import java.util.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicLong; import static 
org.tensorflow.framework.losses.impl.LossesHelper.allAxes; @@ -81,7 +97,7 @@ public static Op assertBroadcastable( && !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") - .withControlDependencies(Collections.EMPTY_LIST) + .withControlDependencies(java.util.Collections.EMPTY_LIST) .noOp(); } if (weightsRankStatic != valuesRankStatic) { @@ -742,7 +758,7 @@ public static Operand confusionMatrix( Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls); tLabels = tfc.identity(tLabels); - tfc = tf.withSubScope("confusionMatrixPredicitons").withControlDependencies(predictionControls); + tfc = tf.withSubScope("confusionMatrixPredictions").withControlDependencies(predictionControls); tPredictions = tfc.identity(tPredictions); Operand shape = tf.stack(Arrays.asList(numClasses, numClasses)); @@ -752,6 +768,7 @@ public static Operand confusionMatrix( SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); Operand zeroMatrix = tf.zeros(shape, type); + return tf.sparse.sparseTensorDenseAdd( cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); } From 21029a7e968260ff2e0d105465761b67c16c2bc5 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 15:51:14 -0500 Subject: [PATCH 66/97] Fix casts --- .../tensorflow/framework/metrics/Metrics.java | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index e4cc9c3aa3d..bcf2ea4c880 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -15,13 +15,14 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; -import 
org.tensorflow.framework.utils.CastHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** Helper class with built-in metrics functions. */ public class Metrics { @@ -49,8 +50,8 @@ public class Metrics { */ public static Operand topKCategoricalAccuracy( Ops tf, Operand labels, Operand predictions, long k) { - Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); - return CastHelper.cast( + Operand fPredictions = cast(tf, predictions, TFloat32.class); + return cast( tf, tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); @@ -81,15 +82,13 @@ public static Operand topKCategoricalAccuracy( @SuppressWarnings("unchecked") public static Operand sparseTopKCategoricalAccuracy( Ops tf, Operand labels, Operand predictions, int k) { - Operand tLabels; - if (labels.type() != predictions.type()) - tLabels = CastHelper.cast(tf, labels, predictions.type()); - else tLabels = (Operand) labels; + Operand tLabels = cast(tf, labels, predictions.type()); + int predictionsRank = predictions.shape().numDimensions(); int labelsRank = tLabels.shape().numDimensions(); - Operand castPredictions = CastHelper.cast(tf, predictions, TFloat32.class); + Operand castPredictions = cast(tf, predictions, TFloat32.class); if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { if (predictionsRank > 2) { castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); @@ -98,9 +97,9 @@ public static Operand sparseTopKCatego tLabels = tf.shape.flatten(tLabels); } } - return CastHelper.cast( + return cast( tf, - tf.nn.inTopK(castPredictions, CastHelper.cast(tf, tLabels, TInt32.class), tf.constant(k)), + tf.nn.inTopK(castPredictions, cast(tf, tLabels, TInt32.class), tf.constant(k)), 
predictions.type()); } } From e154453230da7546d13033e73a908d0b894b6d26 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 19:03:25 -0500 Subject: [PATCH 67/97] Reformat code --- .../org/tensorflow/framework/metrics/AUC.java | 36 +++++----- .../framework/metrics/impl/MetricsHelper.java | 68 +++++++++++++++---- 2 files changed, 72 insertions(+), 32 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index cae67dbd4f0..9fbb3a3ad09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -107,9 +107,7 @@ public class AUC extends Metric { private final Map> initializers = new HashMap<>(); private final Class type; - /** - * The size of the label dimension. - */ + /** The size of the label dimension. */ private Integer numLabels; private Operand labelWeights; @@ -117,7 +115,7 @@ public class AUC extends Metric { /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * - * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single * class dimension within each example. */ private Variable truePositives; @@ -125,7 +123,7 @@ public class AUC extends Metric { /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * - * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single * class dimension within each example. */ private Variable falsePositives; @@ -133,7 +131,7 @@ public class AUC extends Metric { /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * - * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single * class dimension within each example. */ private Variable trueNegatives; @@ -141,7 +139,7 @@ public class AUC extends Metric { /** * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. * - * If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single
 * class dimension within each example.
 */
 private Variable falseNegatives;
@@ -549,8 +547,8 @@ public AUC(
 *
 * @param tf The TensorFlow Ops
 * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}.
- * @param numThresholds the number of thresholds to use when discretizing the roc curve.
- * This includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2.
+ * @param numThresholds the number of thresholds to use when discretizing the roc curve. This
+ * includes the bracketing 0 and 1 thresholds, so the value must be &ge; 2.
 * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link
 * AUCCurve#PR} for the Precision-Recall-curve.
 * @param summationMethod Specifies the Riemann summation method used
@@ -563,10 +561,10 @@ public AUC(
 * if the data should be flattened into a single label before AUC computation. In the latter
 * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an
 * individual data point. Should be set to false for multi-class data.
- * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When
- * multiLabel is true, the weights are applied to the individual label AUCs when
- * they are averaged to produce the multi-label AUC. When it's false, they are used to weight
- * the individual label predictions in computing the confusion matrix on the flattened data.
+ * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When
+ * multiLabel is true, the weights are applied to the individual label AUCs when they
+ * are averaged to produce the multi-label AUC. When it's false, they are used to weight the
+ * individual label predictions in computing the confusion matrix on the flattened data.
 * @param seed the seed for random number generation.
An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. @@ -714,15 +712,13 @@ private Map> build(Shape shape) { /** * Creates a List of Operations to update the metric state based on labels and predictions. * - * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more - * class dimensions, and L1 is a potential extra dimension of size 1 that - * would be squeezed. Will be cast to T. If - * {@link #multiLabel} or if {@link #labelWeights} != null, - * then Cx must be a single dimension. + * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class + * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be + * cast to T. If {@link #multiLabel} or if {@link #labelWeights} != null + * , then Cx must be a single dimension. * @param predictions the predictions shape (N, Cx, P1?). Will be cast to T. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to - * T. - * + * T. 
* @return a List of Operations to update the metric state */ @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 81459515657..cf1755cad56 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -31,7 +31,6 @@ import org.tensorflow.op.math.Mean; import org.tensorflow.op.nn.TopK; - import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -40,7 +39,6 @@ import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; - import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -313,18 +311,23 @@ public static List assertShapes( * * @param tf the TensorFlow Ops * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. + * corresponding variables to update as values. If multiLabel is false then all + * shapes are (T), where T is the number of thresholds. If multiLabel is true + * then all shapes are (T, C0), where C0 is the number of classes. * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to initializer the corresponding variables from * variablesToUpdate. - * @param labels the labels, will be cast to {@link TBool} - * @param predictions the predictions whose values are in the range [0, 1]. + * @param labels the labels, will be cast to {@link TBool} shape (N, Cx, L1?) where N is the + * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra + * dimension of size 1 that would be squeezed. If multiLabel or if + * labelWeights != null, then Cx must be a single dimension. 
+ * @param predictions the predictions shape (N, Cx, P1?). * @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when * topK is set) - * @param topK Optional, indicates that the positive labels should be limited to the top k - * predictions, may be null. + * @param topK Optional, used only if multiLabel, indicates that only the top k + * predictions should be considered. May be null. * @param classIndex Optional, limits the prediction and labels to the specified class. The - * classIndex is and integer representing a specific classification class's input data.. + * classIndex is an integer index into the first dimension of Cx. * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as * labels, and must be broadcast to labels (i.e., all dimensions * must be either 1, or the same as the corresponding labels @@ -332,7 +335,7 @@ public static List assertShapes( * @param multiLabel indicates whether multidimensional prediction/labels should be treated as * multilabel responses, or flattened into a single label. When true, the values of * variablesToUpdate must have a second dimension equal to the number of labels and - * predictions, and those tensors must not be RaggedTensors. + * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. @@ -367,9 +370,12 @@ public static List updateConfusionMatrixVariables( Operand tPredictions = predictions; Operand tSampleWeight = sampleWeight; + // size of the 0th dimension of thresholds // reshape to scalar for operations later. 
Operand numThresholds = tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); + + // true if we will process thresholds as one-dimensional (possibly because we flatten them) Operand oneThresh; if (multiLabel) { oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); @@ -420,8 +426,11 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1)); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); + Operand numExamples = tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); + + // number of labels (or predictions) per example Operand numLabels = tf.select( tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), @@ -431,11 +440,27 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.shape.takeLast( predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); + + // If we will treat thresholds as one-dimensional (always true as of this writing), + // then threshLabelTile is the number of labels (or predictions) per sample. Else it is 1. Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); - // The ExtraDims are added so the operands of the tile operations later on are compatible. + ///////// + // Tile data for threshold comparisons, which is a cross product of thresholds and + // predictions/labels. + // + // In the multilabel case, we want a data shape of (T, N, D0). + // else (T, ND). + // where T is numThresholds + // Dx == Cx except that D0 == 1 if classIndex != null + // ND is the product of N and all Dx. + // In these comments, we refer to all indices beyond the threshold index as a "data position". 
+ + // if multilabel, then shape (1, N, D0) + // else shape (1, ND), Operand predictionsExtraDim; Operand labelsExtraDim; + if (multiLabel) { predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0)); labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0)); @@ -443,9 +468,24 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1))); labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1))); } + + // the shape of each thresholds tile + // if multilabel, then [T, 1, -1] + // else [T, -1] List> threshPretileShape; + + // the tiling multiples for thresholds + // We want to repeat the thresholds for each data position. + // if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels) + // else [1, ND] List> threshTiles; + + // tiling multiples for predictionsExtraDim and labelsExtraDim + // We want to repeat the predictions and labels for each threshold. + // If multilabel, then [T, 1, 1] + // else [T, 1] List> dataTiles; + if (multiLabel) { threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); threshTiles = Arrays.asList(tf.constant(1), numExamples, threshLabelTile); @@ -461,9 +501,14 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand thresholdsReshaped = tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); Operand threshTilesShape = tf.stack(threshTiles); + + // if multilabel, then shape (T, N, threshLabelTile) + // else (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); - Operand dataTilesShape = tf.stack(dataTiles); + // if multilabel, then shape (T, N, D0) + // else (T, ND) + Operand dataTilesShape = tf.stack(dataTiles); Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); // Compare predictions and threshold. 
@@ -768,7 +813,6 @@ public static Operand confusionMatrix( SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); Operand zeroMatrix = tf.zeros(shape, type); - return tf.sparse.sparseTensorDenseAdd( cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); } From eb0a7e609f90aa61b45009b88995ce60beb8f7f6 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 19:04:27 -0500 Subject: [PATCH 68/97] Reformat code --- .../tensorflow/framework/metrics/AUCTest.java | 6 +++-- .../framework/metrics/BinaryAccuracyTest.java | 3 +-- .../metrics/CategoricalAccuracyTest.java | 9 +++---- .../framework/metrics/PrecisionTest.java | 16 ++++--------- .../framework/metrics/RecallTest.java | 24 +++++++------------ 5 files changed, 21 insertions(+), 37 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java index 88825b5f32e..857a5c93f7a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -23,7 +23,9 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.tensorflow.framework.utils.CastHelper.cast; public class AUCTest { @@ -199,7 +201,7 @@ public void testWeightedRocMinoring() { session.run(update); Operand result = instance.result(); - float expectedResult = ( 0.5714285f + 0f * 0f); + float expectedResult = (0.5714285f + 0f * 0f); session.evaluate(expectedResult, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java index e8d8350dcdc..d203815f4ab 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java @@ -155,8 +155,7 @@ public void testVariableState() { public void testBinaryAccuracyAThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - BinaryAccuracy instance = - new BinaryAccuracy<>(tf, 0.7f, 1001L, TFloat32.class); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 0.7f, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = {1, 1, 0, 0}; float[] predArray = {0.9f, 0.6f, 0.4f, 0.8f}; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java index 83990cbaebb..aea2e4e0d6e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java @@ -31,8 +31,7 @@ public class CategoricalAccuracyTest { public void testCorrect() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = - new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { 0, 0, 1, @@ -60,8 +59,7 @@ public void testCorrect() { public void testSampleWeight() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = - new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + CategoricalAccuracy instance = new 
CategoricalAccuracy<>(tf, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { 0, 0, 1, @@ -92,8 +90,7 @@ public void testSampleWeight() { public void testVariableState() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - CategoricalAccuracy instance = - new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); session.run(instance.resetStates()); int[] trueArray = { 0, 0, 1, diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java index 148ca520d3f..cfe5b483e2b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java @@ -16,7 +16,6 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.framework.metrics.impl.MetricsHelper; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; @@ -203,8 +202,7 @@ public void testUnweightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 3 - Precision instance = - new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); @@ -220,8 +218,7 @@ public void testWeightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 3 - Precision instance = - new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + Precision instance = new 
Precision<>(tf, null, 3, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[] {0.2f, 0.1f, 0.4f, 0f, 0.2f}); @@ -249,8 +246,7 @@ public void testUnweightedClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set classId to 2 - Precision instance = - new Precision<>(tf, null, null, 2, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, null, null, 2, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); @@ -290,8 +286,7 @@ public void testUnweightedTopKAndClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK and classId to 2 - Precision instance = - new Precision<>(tf, null, 2, 2, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, null, 2, 2, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0f, 0.2f}}); @@ -321,8 +316,7 @@ public void testUnweightedTopKAndThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); // set topK to 2 - Precision instance = - new Precision<>(tf, 0.7f, 2, null, 1001L, TFloat32.class); + Precision instance = new Precision<>(tf, 0.7f, 2, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java index b9d067a6ed2..bd9fbb1ab66 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java @@ -150,8 +150,7 @@ public void 
testDivByZero() { public void testUnweightedWithThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{1, 0, 0.6f, 0}}); @@ -169,8 +168,7 @@ public void testUnweightedWithThreshold() { public void testWeightedWithThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); @@ -192,8 +190,7 @@ public void testWeightedWithThreshold() { public void testMultipleUpdates() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); @@ -215,8 +212,7 @@ public void testMultipleUpdates() { public void testUnweightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand labels = tf.constant(new float[][] {{0f, 1f, 1f, 0f, 0f}}); @@ -233,8 +229,7 @@ public void testUnweightedTopK() { public void testWeightedTopK() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = 
session.getTF(); - Recall instance = - new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 1}}); @@ -262,8 +257,7 @@ public void testWeightedTopK() { public void testUnweightedClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, null, null, 2, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, null, null, 2, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); @@ -296,8 +290,7 @@ public void testUnweightedClassId() { public void testUnweightedTopKAndClassId() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, null, 2, 2, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, null, 2, 2, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0, 0.2f}}); @@ -324,8 +317,7 @@ public void testUnweightedTopKAndClassId() { public void testUnweightedTopKAndThreshold() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Recall instance = - new Recall<>(tf, null, 0.7f, 2, null, 1001L, TFloat32.class); + Recall instance = new Recall<>(tf, null, 0.7f, 2, null, 1001L, TFloat32.class); session.run(instance.resetStates()); Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); From a29f8e926c5bc67beb0691a486ae73a7fbe171ae Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 11 Mar 2021 19:24:24 -0500 Subject: [PATCH 69/97] Fix javadoc change >= to ≥ --- .../main/java/org/tensorflow/framework/metrics/AUC.java | 9 +++++++-- 1 file changed, 7 insertions(+), 
2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 9fbb3a3ad09..420b1567496 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -26,7 +26,12 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -548,7 +553,7 @@ public AUC( * @param tf The TensorFlow Ops * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. * @param numThresholds the number of thresholds to use when discretizing the roc curve. This - * includes the bracketing 0 and 1 thresholds, so the value must be &GE; 2. + * includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link * AUCCurve#PR} for the Precision-Recall-curve. 
* @param summationMethod Specifies the Riemann summation method used From d47c3b8c386ec707193709a7d3f6eaa1b93cd298 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 18 Mar 2021 12:09:54 -0400 Subject: [PATCH 70/97] Fix spelling in JavaDoc --- .../java/org/tensorflow/framework/metrics/AUC.java | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 420b1567496..72a1f022b41 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -63,7 +63,7 @@ *

Usage:
* *

- * AUC m = new  org.tensorflow.framework.metrcis.AUC( tf, 3);
+ * AUC m = new  org.tensorflow.framework.metrics.AUC( tf, 3);
  * m.updateState( tf.constant(new float[] {0, 0, 1,1}),
  *          tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
  *
@@ -603,7 +603,7 @@ public AUC(
         if (t < 0.0f || t > 1.0f) {
           throw new IllegalArgumentException(
               String.format(
-                  "Threshold values must be in [0, 1]. Invalid values: %s",
+                  "Threshold values must be in range [0, 1], inclusive. Invalid values: %s",
                   Arrays.toString(thresholds)));
         }
       }
@@ -621,12 +621,7 @@ public AUC(
         thresholds[i] = (i + 1) * 1.0f / (this.numThresholds - 1);
       }
     }
-    // Add an endpoint "threshold" below zero and above one for either
-    // threshold method to account for floating point imprecision.
-    if (thresholds.length != this.numThresholds - 2) {
-      throw new IllegalArgumentException(
-          "Thresholds length must contain numThresholds - 2 entries");
-    }
+
     // Add an endpoint "threshold" below zero and above one for either
     // threshold method to account for floating point imprecisions.
     this.thresholds = new float[this.numThresholds];
@@ -754,7 +749,7 @@ public List updateStateList(
         symbols.add(new SymbolicShape<>(falseNegatives, "T", "L"));
       }
       if (getLabelWeights() != null) {
-        symbols.add(new SymbolicShape<>(getLabelWeights(), "L", ""));
+        symbols.add(new SymbolicShape<>(getLabelWeights(), "L"));
       }
       updateOperations.addAll(
           MetricsHelper.assertShapes(tf, symbols, "Number of labels is not consistent."));

From 7f46673aa477af71f1c4a07d5dacf1005615f270 Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Thu, 18 Mar 2021 12:11:33 -0400
Subject: [PATCH 71/97] Change assertShapes to use runtime sizes as Operands
 rather than use primitive long.

---
 .../tensorflow/framework/metrics/impl/MetricsHelper.java | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
index cf1755cad56..07a630560fe 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -255,7 +255,7 @@ public static List assertShapes(
           updateOperations.add(assertion);
         });
 
-    Map dict = new HashMap<>();
+    Map> dict = new HashMap<>();
 
     // check that each operand's dimension size equals the corresponding symbolic shape's dimensions
     // size
@@ -266,9 +266,10 @@ public static List assertShapes(
               .getSymbols()
               .forEach(
                   s -> {
-                    Long size = dict.get(s);
+                    Operand size = dict.get(s);
                     if (size == null) {
-                      size = symbol.getOperand().shape().size((int) ll.get());
+                      // save size for later checks
+                      size = tf.shape.size( symbol.getOperand(), tf.constant(ll.get()), TInt64.class);
                       dict.put(s, size);
                     }
                     Op assertion =
@@ -279,7 +280,7 @@ public static List assertShapes(
                                         symbol.getOperand(),
                                         tf.constant(ll.getAndIncrement()),
                                         TInt64.class),
-                                    tf.constant(size)),
+                                        size),
                                 Collections.singletonList(tf.constant(message)));
                     updateOperations.add(assertion);
                   });

From 41fde6571a832e529d0ed4f7604e377fe2b39982 Mon Sep 17 00:00:00 2001
From: Jim Clarke 
Date: Sun, 21 Mar 2021 08:09:04 -0400
Subject: [PATCH 72/97] Replace calls to tf.slice with private method slice to
 clean up code. Added methods isPositive and positive to clarify what was
 being done

---
 .../org/tensorflow/framework/metrics/AUC.java | 158 ++++++++++++------
 1 file changed, 103 insertions(+), 55 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
index 72a1f022b41..b09154b46da 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java
@@ -24,6 +24,7 @@
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.types.TBool;
 import org.tensorflow.types.family.TNumber;
 
 import java.util.ArrayList;
@@ -780,65 +781,108 @@ public List updateStateList(
     return updateOperations;
   }
 
+  /**
+   * Returns the input with negative values set to 0 (element-wise maximum with 0).
+   *
+   * @param input the input
+   * @return the input with negative values set to 0
+   */
+  private Operand positive(Operand input) {
+    return getTF().math.maximum(input, cast(getTF(), getTF().constant(0), input.type()));
+  }
+
+  /**
+   * Returns a boolean operand indicating, element-wise, whether each value of the input is greater than 0.
+   *
+   * @param input the input
+   * @return a boolean operand whose elements indicate whether the corresponding input values are greater than 0.
+   */
+  private Operand isPositive(Operand input) {
+    return getTF().math.greater(input, cast(getTF(), getTF().constant(0), input.type()));
+  }
+
+  /**
+   * Extracts a slice from the input.
+   *
+   * @param input the input
+   * @param begin the beginning location of the slice
+   * @param size the size of the slice
+   * @return the slice
+   */
+  private Operand slice(Operand input, int begin, int size) {
+    return getTF()
+        .slice(input, getTF().constant(new int[] {begin}), getTF().constant(new int[] {size}));
+  }
+
   /**
    * Interpolation formula inspired by section 4 of Davis & Goadrich 2006.
    *
+   * 

Note here we derive & use a closed formula not present in the paper as follows: + *

+   *     Precision = TP / (TP + FP) = TP / P
+   * 
+ *

Modeling all of TP (true positive), FP (false positive) and their sum + * P = TP + FP (predicted positive) as varying linearly within each interval + * [A, B] between successive thresholds, we get

+ *
+   *     Precision slope = dTP / dP
+   *                     = (TP_B - TP_A) / (P_B - P_A)
+   *                     = (TP - TP_A) / (P - P_A)
+   *     Precision = (TP_A + slope * (P - P_A)) / P
+   * 
+ *

The area within the interval is (slope / total_pos_weight) times + *

+   *       int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+   *       int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+   * 
+ * where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + *
+   *       int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+   * 
+ * Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + *
+   *       slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
+   * 
+ * where dTP == TP_B - TP_A. + * Note that when P_A == 0 the above calculation simplifies into + *
+   *       int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+   * 
+ * which is really equivalent to imputing constant precision throughout the + * first bucket having >0 true positives. + * * @return an approximation of the area under the P-R curve. + * @see The Relationship Between Precision-Recall and ROC Curves - Davis & Goadrich 2006 */ private Operand interpolatePRAuc() { // truePositives[:self.numThresholds - 1] Ops tf = getTF(); - Operand tp0 = - tf.slice( - truePositives, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})); + Operand tp0 = slice(truePositives, 0, getNumThresholds() - 1); // truePositives[1:] - Operand tp1 = - tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand tp1 = slice(truePositives, 1, -1); Operand dTP = tf.math.sub(tp0, tp1); Operand p = tf.math.add(truePositives, falsePositives); - Operand dP = - tf.math.sub( - tf.slice( - p, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})), - tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))); + Operand dP = tf.math.sub(slice(p, 0, getNumThresholds() - 1), slice(p, 1, -1)); Operand precisionSlope = - tf.math.divNoNan(dTP, tf.math.maximum(dP, tf.dtypes.cast(tf.constant(0), dP.type()))); + tf.math.divNoNan(dTP, positive(dP)); Operand intercept = - tf.math.sub( - tf.slice(truePositives, tf.constant(new int[] {1}), tf.constant(new int[] {-1})), - tf.math.mul( - precisionSlope, - tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})))); + tf.math.sub(slice(truePositives, 1, -1), tf.math.mul(precisionSlope, slice(p, 1, -1))); Operand safePRatio = tf.select( tf.math.logicalAnd( - tf.math.greater( - tf.slice( - p, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})), - tf.dtypes.cast(tf.constant(0), p.type())), - tf.math.greater( - tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})), - tf.dtypes.cast(tf.constant(0), p.type()))), + isPositive(slice(p, 0, getNumThresholds() - 1)), isPositive(slice(p, 1, 
-1))), tf.math.divNoNan( - tf.slice( - p, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})), - tf.math.maximum( - tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})), - tf.dtypes.cast(tf.constant(0), p.type()))), - tf.onesLike(tf.slice(p, tf.constant(new int[] {1}), tf.constant(new int[] {-1})))); + slice(p, 0, getNumThresholds() - 1), + positive(slice(p, 1, -1))), + tf.onesLike(slice(p, 1, -1))); - Operand fn1 = - tf.slice(falseNegatives, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand fn1 = slice(falseNegatives, 1, -1); Operand aucTotalPos = tf.math.mul( @@ -847,14 +891,15 @@ private Operand interpolatePRAuc() { Operand prAucIncrement = tf.math.divNoNan( aucTotalPos, - tf.math.maximum( - tf.math.add(tp1, fn1), tf.dtypes.cast(tf.constant(0), truePositives.type()))); + positive(tf.math.add(tp1, fn1))); if (isMultiLabel()) { Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0)); if (getLabelWeights() == null) { + //Evenly weighted average of the label AUCs. return MetricsHelper.mean(tf, byLabelAuc); } else { + // Weighted average of the label AUCs. 
return tf.math.divNoNan( tf.reduceSum(tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, byLabelAuc)), tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()))); @@ -877,22 +922,27 @@ public Operand result() { Operand y; Operand recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); - if (getCurve() == AUCCurve.ROC) { - x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives)); - y = recall; - } else { // AUCCurve.PR - y = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); - x = recall; + switch (getCurve()) { + case ROC: + x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives)); + y = recall; + break; + case PR: + y = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + x = recall; + break; + default: + throw new IllegalArgumentException("Unexpected AUCCurve value: " + getCurve()); } // Find the rectangle heights based on `summationMethod`. // y[:self.numThresholds - 1] - Operand ySlice1 = - tf.slice(y, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); + Operand ySlice1 = slice(y, 0, getNumThresholds() - 1); // y[1:] - Operand ySlice2 = tf.slice(y, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand ySlice2 = slice(y, 1, -1); + - Operand heights = null; + Operand heights; switch (getSummationMethod()) { case INTERPOLATION: heights = tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type())); @@ -903,17 +953,16 @@ public Operand result() { case MAJORING: heights = tf.math.maximum(ySlice1, ySlice2); break; + default: + throw new IllegalArgumentException("Unexpected AUCSummationMethod value: " + getSummationMethod()); } if (isMultiLabel()) { Operand riemannTerms = tf.math.mul( tf.math.sub( - tf.slice( - x, - tf.constant(new int[] {0}), - tf.constant(new int[] {getNumThresholds() - 1})), - tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1}))), + 
slice(x, 0, getNumThresholds() - 1), + slice(x, 1, -1)), heights); Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0)); @@ -928,9 +977,8 @@ public Operand result() { } } else { - Operand slice1 = - tf.slice(x, tf.constant(new int[] {0}), tf.constant(new int[] {getNumThresholds() - 1})); - Operand slice2 = tf.slice(x, tf.constant(new int[] {1}), tf.constant(new int[] {-1})); + Operand slice1 = slice(x,0, getNumThresholds() - 1); + Operand slice2 = slice(x, 1, -1); Operand sub = tf.math.sub(slice1, slice2); Operand operand = tf.math.mul(sub, heights); return tf.reduceSum(operand, allAxes(tf, operand)); From 3c7e3a7f34b20ba908dcdc2a3bfcf4e56fdc6824 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 08:09:48 -0400 Subject: [PATCH 73/97] Fix Javdoc, remove spurious y_pred. --- .../org/tensorflow/framework/metrics/CategoricalAccuracy.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index c0635746d4d..c3780cc6de2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,7 +27,7 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

You can provide logits of classes as predictionsy_pred, since argmax + *

You can provide logits of classes as predictions, since argmax * of logits and probabilities are same. * *

This metric creates two local variables, total and count that are From 4a114be82c835803424b02d14ca434e281bd3294 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 08:10:41 -0400 Subject: [PATCH 74/97] remove spurious cast --- .../org/tensorflow/framework/metrics/impl/MetricsHelper.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 07a630560fe..a754c93be46 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -604,8 +604,7 @@ private static Operand weightedAssignAdd( Operand labelAndPred = cast(tf, tf.math.logicalAnd(labels, predictions), type); if (weights != null) { - Operand lWeights = cast(tf, weights, type); - labelAndPred = tf.math.mul(labelAndPred, lWeights); + labelAndPred = tf.math.mul(labelAndPred, weights); } Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); Operand assignAdd; From da72efd64f183fee85686a0104b3ce76d397556b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 08:11:20 -0400 Subject: [PATCH 75/97] correct comments for enums --- .../framework/metrics/impl/ConfusionMatrixEnum.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java index 281aa2072d0..bf3ade53d73 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -18,9 +18,9 @@ public enum ConfusionMatrixEnum { /** These 
are cases in which the prediction is true, and reality is true. */ TRUE_POSITIVES("tp"), - /** These are cases in which the prediction is false, and reality is true. */ - FALSE_POSITIVES("fp"), /** These are cases in which the prediction is true, and reality is false. */ + FALSE_POSITIVES("fp"), + /** These are cases in which the prediction is false, and reality is true. */ TRUE_NEGATIVES("tn"), /** These are cases in which the prediction is false, and reality is false. */ FALSE_NEGATIVES("fn"); From fb1ab3a97a3466b92a8de973db116d41895f1a9b Mon Sep 17 00:00:00 2001 From: deansher Date: Thu, 18 Mar 2021 07:53:34 -0400 Subject: [PATCH 76/97] Revised and improved internal docs in MetricsHelper.java --- .../framework/metrics/impl/MetricsHelper.java | 122 +++++++++++------- 1 file changed, 75 insertions(+), 47 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index a754c93be46..5e29b8a1002 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -305,45 +305,47 @@ public static List assertShapes( * will repeat the same for every threshold. * *

For estimation of these metrics over a stream of data, the function creates an `update_op` - * operation that updates the given variables. + * operation that updates the given variables.

* - *

If sampleWeight is null, weights default to 1. Use weights of 0 to - * mask values. + *

labels, predictions, and sampleWeight tensors are + * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. + * sampleWeight is then broadcast to the shape of predictions.

* * @param tf the TensorFlow Ops - * @param variablesToUpdate Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If multiLabel is false then all - * shapes are (T), where T is the number of thresholds. If multiLabel is true - * then all shapes are (T, C0), where C0 is the number of classes. - * @param varInitializers Map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding initializer Operands to initializer the corresponding variables from - * variablesToUpdate. - * @param labels the labels, will be cast to {@link TBool} shape (N, Cx, L1?) where N is the + * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding variables to update as values. If multiLabel, then the + * variable shapes are (T, D), where T is the number of thresholds and D is the number of + * classes (after slicing by classIndex, if provided). + * If multiLabels, then the variable shapes are (T). + * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding initializer Operands to for variablesToUpdate. + * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra - * dimension of size 1 that would be squeezed. If multiLabel or if - * labelWeights != null, then Cx must be a single dimension. - * @param predictions the predictions shape (N, Cx, P1?). - * @param thresholds thresholds in `the range [0, 1], or {@link #NEG_INF} (used when - * topK is set) - * @param topK Optional, used only if multiLabel, indicates that only the top k - * predictions should be considered. May be null. - * @param classIndex Optional, limits the prediction and labels to the specified class. The - * classIndex is an integer index into the first dimension of Cx. 
- * @param sampleWeight Optional Tensor whose rank is either 0, or the same rank as - * labels, and must be broadcast to labels (i.e., all dimensions - * must be either 1, or the same as the corresponding labels - * dimension). + * dimension of size 1 that would be squeezed. + * @param predictions the predictions shape (N, Cx, P1?) + * @param thresholds thresholds in the range [0, 1], or {@link #NEG_INF} is used + * when topK is set + * @param topK optional, indicates that only the top k predictions should be considered. + * Applied before possibly slicing by classIndex. + * @param classIndex optional, limits the prediction and labels to the specified class. + * This is an integer index into the first dimension of Cx. + * @param sampleWeight optional Tensor that is aligned with labels and predictions + * as explained above. Use weights of 0 to mask values. * @param multiLabel indicates whether multidimensional prediction/labels should be treated as * multilabel responses, or flattened into a single label. When true, the values of * variablesToUpdate must have a second dimension equal to the number of labels and * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES - * without explicit multilabel handling (i.e. when the data is to be flattened). May be null. + * without explicit multilabel handling (i.e. when the data is to be flattened). + * Must have shape (Dx), which is the same as (Cx) referenced above, except that if + * classIndex is provided, then the final dimension of Dx is 1. These weights + * will be broadcast across the 0th dimension (the examples dimension) of + * predictions. May be null. Must be null if multiLabel. 
* @param the data type for the variables * @throws IllegalArgumentException If predictions and labels have * mismatched shapes, or if sampleWeight is not nulland its shape - * doesn't match predictions + * doesn't match predictions, or if multiLabel && labelWeights != null. * @return an op to update the given confusion matrix variables. */ @SuppressWarnings({"unchecked", "rawtypes"}) @@ -371,12 +373,25 @@ public static List updateConfusionMatrixVariables( Operand tPredictions = predictions; Operand tSampleWeight = sampleWeight; + // We will tile data for threshold comparisons. We want a cross product of thresholds and + // predictions/labels: + // In the multilabel case, we want a data shape of (T, N, D). + // else (T, ND). + // where + // T is numThresholds (the size of the 0th dimension of thresholds) + // N is the number of examples (the 0th dimension of labels and predictions) + // Dx == Cx except that if classIndex != null, + // then the last dimension of Dx is size 1 + // D is the product of all Dx + // ND is N * D + // size of the 0th dimension of thresholds // reshape to scalar for operations later. 
Operand numThresholds = tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); - // true if we will process thresholds as one-dimensional (possibly because we flatten them) + // if multilabel, then (rank(thresholds) == 1) + // else true Operand oneThresh; if (multiLabel) { oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); @@ -408,9 +423,9 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), LossTuple result = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight); - tPredictions = result.getTarget(); - tLabels = result.getLabels(); - tSampleWeight = result.getSampleWeights(); + tPredictions = result.getTarget(); // shape (N, Cx) + tLabels = result.getLabels(); // shape (N, Cx) + tSampleWeight = result.getSampleWeights(); // broadcastable to (N, Dx) if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) throw new IllegalArgumentException( @@ -423,6 +438,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), } if (classIndex != null) { + // Slice to new shapes (N, Dx) tLabels = tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(1)); tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1)); } @@ -431,7 +447,8 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand numExamples = tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); - // number of labels (or predictions) per example + // number of labels (and predictions) per example (after possibly slicing by classIndex) + // In the notation we are using for comments, we'll call this D. 
Operand numLabels = tf.select( tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), @@ -442,22 +459,11 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), tf.constant(0))); - // If we will treat thresholds as one-dimensional (always true as of this writing), - // then threshLabelTile is the number of labels (or predictions) per sample. Else it is 1. + // threshLabelTile == numLabels except in one case: + // if multilabel and rank(thresholds) != 1, then threshLabelTile is 1 Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); - ///////// - // Tile data for threshold comparisons, which is a cross product of thresholds and - // predictions/labels. - // - // In the multilabel case, we want a data shape of (T, N, D0). - // else (T, ND). - // where T is numThresholds - // Dx == Cx except that D0 == 1 if classIndex != null - // ND is the product of N and all Dx. - // In these comments, we refer to all indices beyond the threshold index as a "data position". 
- - // if multilabel, then shape (1, N, D0) + // if multilabel, then shape (1, N, Dx) // else shape (1, ND), Operand predictionsExtraDim; Operand labelsExtraDim; @@ -499,17 +505,22 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), dataTiles = Arrays.asList(numThresholds, tf.constant(1)); } + // if multilabel, then shape (T, 1, T*) + // else shape (T, T*) + // where T* is the product of all threshold dimension sizes beyond 0 Operand thresholdsReshaped = tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); + Operand threshTilesShape = tf.stack(threshTiles); // if multilabel, then shape (T, N, threshLabelTile) // else (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); - // if multilabel, then shape (T, N, D0) - // else (T, ND) Operand dataTilesShape = tf.stack(dataTiles); + + // if multilabel, then shape (T, N, D) + // else (T, ND) Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); // Compare predictions and threshold. @@ -519,17 +530,30 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand weightsTiled; if (tSampleWeight != null) { tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); + // if multilabel, then + // reshape tSampleWeight to (1, N, threshLabelTile) + // tile the result into shape (T, N, threshLabelTile) + // where threshLabelTile is typically D + // else + // reshape tSampleWeight to (1, ND) + // tile the result into shape (T, ND) weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); } else { weightsTiled = null; } if (labelWeights != null) { + // Change shape to (1, Dx). Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0)); + // Broadcast to shape (N, Dx). 
lLabelWeights = tf.broadcastTo(lLabelWeights, tPredictions); + + // If multilabel: shape (T, N, D) + // else: shape (T, ND) Operand labelWeightsTiled = tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles)); + if (weightsTiled == null) { weightsTiled = labelWeightsTiled; } else { @@ -606,6 +630,10 @@ private static Operand weightedAssignAdd( if (weights != null) { labelAndPred = tf.math.mul(labelAndPred, weights); } + // if multilabel: + // sum across examples, leaving shape (T, D) + // else: + // sum across ND, leaving shape (T) Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); Operand assignAdd; if (initializer != null) { From b2937cdae7d8cbe6aaa01ea04cb50cc6c2c2e329 Mon Sep 17 00:00:00 2001 From: deansher Date: Fri, 19 Mar 2021 11:11:20 -0400 Subject: [PATCH 77/97] Tweaked internal docs in MetricsHelper.java --- .../tensorflow/framework/metrics/impl/MetricsHelper.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 5e29b8a1002..6d35d1a71c4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -448,7 +448,7 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); // number of labels (and predictions) per example (after possibly slicing by classIndex) - // In the notation we are using for comments, we'll call this D. + // In the notation we are using for comments, this is D. 
Operand numLabels = tf.select( tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), @@ -513,8 +513,10 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), Operand threshTilesShape = tf.stack(threshTiles); - // if multilabel, then shape (T, N, threshLabelTile) - // else (T, ND) + // if multilabel, then + // if thresholds has rank > 1, then shape (T, N, T*) + // else shape (T, N, D) + // else shape (T, ND) Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); Operand dataTilesShape = tf.stack(dataTiles); From d9c8352c33a03ba4451f730cdada4c558d5c219b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 16:21:24 -0400 Subject: [PATCH 78/97] Fix the documentation on TP, FP, TN, and FN --- .../framework/metrics/impl/ConfusionMatrixEnum.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java index bf3ade53d73..caa5f203f9f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -20,9 +20,9 @@ public enum ConfusionMatrixEnum { TRUE_POSITIVES("tp"), /** These are cases in which the prediction is true, and reality is false. */ FALSE_POSITIVES("fp"), - /** These are cases in which the prediction is false, and reality is true. */ - TRUE_NEGATIVES("tn"), /** These are cases in which the prediction is false, and reality is false. */ + TRUE_NEGATIVES("tn"), + /** These are cases in which the prediction is false, and reality is true. 
*/ FALSE_NEGATIVES("fn"); private final String abbrev; From ce07c257a88deeba5874156cf10797a7ce95f9af Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 21 Mar 2021 16:50:09 -0400 Subject: [PATCH 79/97] Added code comments to fitlerTopK. --- .../org/tensorflow/framework/metrics/impl/MetricsHelper.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 6d35d1a71c4..d34a7b25111 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -664,7 +664,9 @@ private static Operand weightedAssignAdd( private static Operand filterTopK(Ops tf, Operand x, int topK) { Class type = x.type(); Shape xShape = x.shape(); + // top has the same rank as x; the last dimension becomes indices of the topK features. TopK top = tf.nn.topK(x, tf.constant(topK), TopK.sorted(false)); + // oneHot has an additional dimension: the one-hot representation of each topK index. OneHot oneHot = tf.oneHot( top.indices(), @@ -672,6 +674,7 @@ private static Operand filterTopK(Ops tf, Operand x, i tf.constant(1), tf.constant(0), OneHot.axis(-1L)); + // Sum the one-hot representations along the last dimension of x. 
Operand topKMask = cast(tf, tf.reduceSum(oneHot, tf.constant(-2)), type); // x * top_k_mask + NEG_INF * (1 - top_k_mask) From e9f1a3580bda8c5e8599e775e157032fdb38a324 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:33:51 -0400 Subject: [PATCH 80/97] JavaDoc fixes and code cleanup and add code comments --- .../org/tensorflow/framework/metrics/AUC.java | 18 ++++++++++-------- .../tensorflow/framework/metrics/MeanIoU.java | 14 +++++++++++--- .../framework/metrics/MeanTensor.java | 15 ++++++++++++--- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index b09154b46da..01200fd39b4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -792,10 +792,10 @@ private Operand positive(Operand input) { } /** - * Gets an operand that determines whether the input consists of each value is greater than 0. + * Gets the truth value of whether {@code input > 0}, element-wise. * * @param input the input - * @return an operand that determines whether the input consists of all values greater than 0. + * @return the truth value of whether {@code input > 0}, element-wise. 
*/ private Operand isPositive(Operand input) { return getTF().math.greater(input, cast(getTF(), getTF().constant(0), input.type())); @@ -864,23 +864,25 @@ private Operand interpolatePRAuc() { Operand dTP = tf.math.sub(tp0, tp1); Operand p = tf.math.add(truePositives, falsePositives); + Operand p0= slice(p, 0, getNumThresholds() - 1); + Operand p1= slice(p, 1, -1); - Operand dP = tf.math.sub(slice(p, 0, getNumThresholds() - 1), slice(p, 1, -1)); + Operand dP = tf.math.sub(p0,p1); Operand precisionSlope = tf.math.divNoNan(dTP, positive(dP)); Operand intercept = - tf.math.sub(slice(truePositives, 1, -1), tf.math.mul(precisionSlope, slice(p, 1, -1))); + tf.math.sub(tp1, tf.math.mul(precisionSlope, p1)); Operand safePRatio = tf.select( tf.math.logicalAnd( - isPositive(slice(p, 0, getNumThresholds() - 1)), isPositive(slice(p, 1, -1))), + isPositive(p0), isPositive(p1)), tf.math.divNoNan( - slice(p, 0, getNumThresholds() - 1), - positive(slice(p, 1, -1))), - tf.onesLike(slice(p, 1, -1))); + p0, + positive(p1)), + tf.onesLike(p1)); Operand fn1 = slice(falseNegatives, 1, -1); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 19b13ed391c..3cd3fd7c0ee 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -124,13 +124,17 @@ public List updateStateList( Operand sampleWeights) { Operand tLabels = cast(getTF(), labels, type); - if (tLabels.shape().numDimensions() > 1) tLabels = getTF().shape.flatten(tLabels); + if (tLabels.shape().numDimensions() > 1) { + tLabels = getTF().shape.flatten(tLabels); + } Operand tPredictions = cast(getTF(), predictions, type); - if (tPredictions.shape().numDimensions() > 1) + if (tPredictions.shape().numDimensions() > 1) { tPredictions = getTF().shape.flatten(tPredictions); + } Operand 
tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null; - if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) + if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { tSampleWeights = getTF().shape.flatten(tSampleWeights); + } Operand currentCM = MetricsHelper.confusionMatrix( @@ -149,6 +153,10 @@ public Operand result() { totalConfusionMatrix, tf.constant(0), cast(tf, tf.constant(0), totalConfusionMatrix.type())); + // for each class, the total predictions + total labels - true positives + // Observe that total predictions = tp + fp + // total labels = tp + fn + // So this is 2*tp + fp + fn - tp = tp + fp + fn Operand denominator = tf.math.add(sumOverRow, tf.math.sub(sumOverCol, truePositives)); Operand numValidEntries = tf.reduceSum( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index 3d6d8194aac..f01cb47b256 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -103,14 +103,19 @@ private boolean init(Shape shape) { } } - /** {@inheritDoc} */ + /** + * Accumulates statistics for computing the element-wise mean. + * + * @param values Per-example value. Input values must always have the same shape for all + * invocations of updateStateList. + * @param sampleWeights Optional weighting of each example. Defaults to 1 if null. + */ @Override public List updateStateList( Operand values, Operand sampleWeights) { Ops tf = getTF(); Operand tValues = cast(tf, values, type); - Operand tSampleWeights = null; - if (sampleWeights != null) tSampleWeights = cast(tf, sampleWeights, type); + Operand tSampleWeights = sampleWeights == null ? 
null : cast(tf, sampleWeights, type); boolean needsInitialization = init(values.shape()); @@ -123,13 +128,17 @@ public List updateStateList( Operand numValues = tf.onesLike(tValues); if (tSampleWeights != null) { + //Update dimensions of weights to match with values if possible. LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights); tValues = tuple.getTarget(); tSampleWeights = tuple.getSampleWeights(); try { + // Broadcast weights if possible. tSampleWeights = WeightsBroadcastOps.broadcastWeights(tf, tSampleWeights, tValues); } catch (IllegalArgumentException ex) { + // sampleWeights cannot be broadcast to values + // Reduce values to same ndim as weight array int ndim = values.shape().numDimensions(); int weightNdim = tSampleWeights.asOutput().shape().numDimensions(); int[] range = new int[ndim - weightNdim]; From b1764322c15bab8cf40f4f2f39b43e0c2e0fa535 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:34:59 -0400 Subject: [PATCH 81/97] JavaDoc fixes and code cleanup and add code comments Remose shape flatten in updateStateList --- .../framework/metrics/MeanRelativeError.java | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index 4c48c0f88a7..b8cec2150b7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -33,8 +33,8 @@ * ultimately returned as mean relative error: an idempotent operation that simply divides total by * count. * - *

If sampleWeight is null, weights default to 1. Use sample_weight of - * 0 to mask * values. + *

If {@code sampleWeight} is null, weights default to 1. Use {@code sampleWeight} of + * 0 to mask values. * * @param The data type for the metric result */ @@ -124,27 +124,29 @@ protected MeanRelativeError( this.normalizer = normalizer; } - /** {@inheritDoc} */ + /** + * Accumulates metric statistics. + * + * @param labels The ground truth values. + * @param predictions The predicted values. Must be the same shape as the normalizer. + * @param sampleWeights Optional weighting of each example. A null value defaults to 1. Can be an {@code Operand} + * whose rank is either 0, or the same rank as {@code labels}, and must be broadcastable to + * {@code labels}. + * @return a List of Operations to update the metric state + */ @Override public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { Operand tLabels = cast(getTF(), labels, getResultType()); - if (tLabels.shape().numDimensions() > 1) tLabels = getTF().shape.flatten(tLabels); - Operand tPredictions = cast(getTF(), predictions, getResultType()); - if (tPredictions.shape().numDimensions() > 1) - tPredictions = getTF().shape.flatten(tPredictions); + Operand tSampleWeights = + sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); tPredictions = tuple.getTarget(); tLabels = tuple.getLabels(); - Operand tSampleWeights = - sampleWeights != null ? 
cast(getTF(), sampleWeights, getResultType()) : null; - if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { - tSampleWeights = getTF().shape.flatten(tSampleWeights); - } tuple = LossesHelper.removeSqueezableDimensions(getTF(), normalizer, tPredictions); normalizer = tuple.getLabels(); From d087a6f491dfe03e94390882ae436dc5755dc075 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:35:46 -0400 Subject: [PATCH 82/97] Fix code in sparseTopKCategoricalAccuracy to reshape to proper dimensions --- .../java/org/tensorflow/framework/metrics/Metrics.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index bcf2ea4c880..3d4c262491f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -23,7 +23,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -/** Helper class with built-in metrics functions. */ +/** Static methods for computing metrics. 
*/ public class Metrics { /** @@ -84,6 +84,7 @@ public static Operand sparseTopKCatego Ops tf, Operand labels, Operand predictions, int k) { Operand tLabels = cast(tf, labels, predictions.type()); + // Flatten predictions to (batch_size, num_samples) and labels to (num_samples,) int predictionsRank = predictions.shape().numDimensions(); int labelsRank = tLabels.shape().numDimensions(); @@ -91,10 +92,13 @@ public static Operand sparseTopKCatego Operand castPredictions = cast(tf, predictions, TFloat32.class); if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { if (predictionsRank > 2) { - castPredictions = tf.shape.reduceDims(castPredictions, tf.constant(1)); + //y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) + castPredictions = tf.reshape(castPredictions, + tf.constant(castPredictions.shape().takeLast(1).prepend(Shape.UNKNOWN_SIZE))); } if (labelsRank > 1) { - tLabels = tf.shape.flatten(tLabels); + //y_true = array_ops.reshape(y_true, [-1]) + tLabels = tf.reshape(tLabels, tf.constant(Shape.of(Shape.UNKNOWN_SIZE))); } } return cast( From 6aec4ff90d46452d363f3f865b621ac584b1e042 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:36:14 -0400 Subject: [PATCH 83/97] Fix JavaDoc --- .../metrics/SparseTopKCategoricalAccuracy.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java index 7db290530cd..0fd600b4a0f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -22,7 +22,10 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -/** @param The data type for the metric result */ +/** + * Computes how often 
integer targets are in the top `K` predictions. + * @param The data type for the metric result + * */ public class SparseTopKCategoricalAccuracy extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_K = 5; @@ -30,8 +33,7 @@ public class SparseTopKCategoricalAccuracy extends MeanMetric private final int k; /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of - * top elements to look at for computing accuracy. + * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top elements. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. @@ -44,7 +46,7 @@ public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class ty } /** - * Creates a TopKCategoricalAccuracy metric + * Creates a SparseTopKCategoricalAccuracy metric. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. From 9090e318d1b75f8c8878fc15f2f011c29cdbf283 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 16:37:48 -0400 Subject: [PATCH 84/97] Fix JavaDoc --- .../framework/metrics/impl/MetricsHelper.java | 34 +++++++++++++++++-- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index d34a7b25111..c06616a6324 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -764,10 +764,37 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( /** * Computes the confusion matrix from predictions and labels. * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape {@code [n, n]}, where {@code n} is the + * number of valid labels for a given classification task. Both prediction and labels must be 1-D + * arrays of the same shape in order for this function to work. + * + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum + * value in either predictions or labels. Class labels are expected to start at 0. For example, if + * {@code numClasses}` is 3, then the possible labels would be {@code [0, 1, 2]}. + * + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to + * the total value of the confusion matrix cell. + * + *

For example: + * + *

+   *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
+   *          [[0 0 0 0 0]
+   *           [0 0 1 0 0]
+   *           [0 0 1 0 0]
+   *           [0 0 0 0 0]
+   *           [0 0 0 0 1]]
+   * 
+ * + * Note that the possible labels are assumed to be {@copde [0, 1, 2, 3,4]}, resulting in a 5x5 + * confusion matrix. + * * @param tf the TensorFlow Ops - * @param labels 1-D `Tensor` of real labels for the classification task. - * @param predictions 1-D `Tensor` of predictions for a given classification. - * @param numClasses The possible number of labels the classification task can have. + * @param labels 1-D {@code Operand} of real labels for the classification task. + * @param predictions 1-D {@code Operand} of predictions for a given classification. + * @param numClasses The possible number of labels the classification task can have. If this value + * is not provided, it will be calculated using both predictions and labels array. * @param weights optional weights to be applied to the confusion matrix * @param type Data type of the confusion matrix. * @param the type of Operands @@ -778,6 +805,7 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * not have compatible shapes, or if weights is notnull and its * shape is not compatible with predictions. */ + // TODO should this be moved to FramnworkOps under math. public static Operand confusionMatrix( Ops tf, Operand labels, From 4e5906c8b93d2cf511b53bdf66e37aa6b5f90190 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 19:22:10 -0400 Subject: [PATCH 85/97] Fixed Javadoc, mainly to add shape requirements. 
Reformat code --- .../org/tensorflow/framework/metrics/AUC.java | 75 +++++++++---------- .../framework/metrics/AUCSummationMethod.java | 6 +- .../framework/metrics/Accuracy.java | 11 ++- .../framework/metrics/BinaryAccuracy.java | 8 +- .../framework/metrics/BinaryCrossentropy.java | 9 ++- .../metrics/CategoricalAccuracy.java | 16 +++- .../metrics/CategoricalCrossentropy.java | 8 +- .../framework/metrics/CategoricalHinge.java | 8 +- .../framework/metrics/CosineSimilarity.java | 18 ++++- .../tensorflow/framework/metrics/Hinge.java | 8 +- .../framework/metrics/KLDivergence.java | 8 +- .../framework/metrics/LogCoshError.java | 8 +- .../framework/metrics/MeanAbsoluteError.java | 8 +- .../metrics/MeanAbsolutePercentageError.java | 8 +- .../tensorflow/framework/metrics/MeanIoU.java | 15 +++- .../framework/metrics/MeanRelativeError.java | 12 +-- .../framework/metrics/MeanSquaredError.java | 21 +++++- .../metrics/MeanSquaredLogarithmicError.java | 8 +- .../framework/metrics/MeanTensor.java | 5 +- .../tensorflow/framework/metrics/Metrics.java | 8 +- .../tensorflow/framework/metrics/Poisson.java | 8 +- .../framework/metrics/Precision.java | 18 ++++- .../framework/metrics/PrecisionAtRecall.java | 1 + .../tensorflow/framework/metrics/Recall.java | 18 ++++- .../metrics/RootMeanSquaredError.java | 10 ++- .../metrics/SparseCategoricalAccuracy.java | 9 ++- .../SparseCategoricalCrossentropy.java | 13 +++- .../SparseTopKCategoricalAccuracy.java | 14 +++- .../framework/metrics/SquaredHinge.java | 10 ++- .../metrics/TopKCategoricalAccuracy.java | 8 +- .../impl/ConfusionMatrixConditionCount.java | 11 ++- .../framework/metrics/impl/MetricsHelper.java | 50 ++++++------- .../framework/metrics/impl/Reduce.java | 12 ++- .../impl/SensitivitySpecificityBase.java | 21 +++++- 34 files changed, 357 insertions(+), 114 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java 
index 01200fd39b4..3dbc6f22cec 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -715,11 +715,11 @@ private Map> build(Shape shape) { * * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be - * cast to T. If {@link #multiLabel} or if {@link #labelWeights} != null + * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} != null * , then Cx must be a single dimension. * @param predictions the predictions shape (N, Cx, P1?). Will be cast to T. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to - * T. + * {@code }. * @return a List of Operations to update the metric state */ @Override @@ -795,7 +795,7 @@ private Operand positive(Operand input) { * Gets the truth value of whether {@code input > 0}, element-wise. * * @param input the input - * @return the truth value of whether {@code input > 0}, element-wise. + * @return the truth value of whether {@code input > 0}, element-wise. */ private Operand isPositive(Operand input) { return getTF().math.greater(input, cast(getTF(), getTF().constant(0), input.type())); @@ -818,41 +818,52 @@ private Operand slice(Operand input, int begin, int size) { * Interpolation formula inspired by section 4 of Davis & Goadrich 2006. * *

Note here we derive & use a closed formula not present in the paper as follows: + * *

    *     Precision = TP / (TP + FP) = TP / P
    * 
- *

Modeling all of TP (true positive), FP (false positive) and their sum - * P = TP + FP (predicted positive) as varying linearly within each interval - * [A, B] between successive thresholds, we get

+ * + *

Modeling all of TP (true positive), FP (false positive) and their sum P = TP + FP (predicted + * positive) as varying linearly within each interval [A, B] between successive thresholds, we get + * *

    *     Precision slope = dTP / dP
    *                     = (TP_B - TP_A) / (P_B - P_A)
    *                     = (TP - TP_A) / (P - P_A)
    *     Precision = (TP_A + slope * (P - P_A)) / P
    * 
+ * *

The area within the interval is (slope / total_pos_weight) times + * *

    *       int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
    *       int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
    * 
- * where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + * + * where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + * *
    *       int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
    * 
- * Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + * + * Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + * *
    *       slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
    * 
- * where dTP == TP_B - TP_A. - * Note that when P_A == 0 the above calculation simplifies into + * + * where dTP == TP_B - TP_A. Note that when P_A == 0 the above calculation simplifies into + * *
    *       int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
    * 
- * which is really equivalent to imputing constant precision throughout the - * first bucket having >0 true positives. + * + * which is really equivalent to imputing constant precision throughout the first bucket having >0 + * true positives. * * @return an approximation of the area under the P-R curve. - * @see The Relationship Between Precision-Recall and ROC Curves - Davis & Goadrich 2006 + * @see The Relationship Between + * Precision-Recall and ROC Curves - Davis & Goadrich 2006 */ private Operand interpolatePRAuc() { // truePositives[:self.numThresholds - 1] @@ -864,24 +875,19 @@ private Operand interpolatePRAuc() { Operand dTP = tf.math.sub(tp0, tp1); Operand p = tf.math.add(truePositives, falsePositives); - Operand p0= slice(p, 0, getNumThresholds() - 1); - Operand p1= slice(p, 1, -1); + Operand p0 = slice(p, 0, getNumThresholds() - 1); + Operand p1 = slice(p, 1, -1); - Operand dP = tf.math.sub(p0,p1); + Operand dP = tf.math.sub(p0, p1); - Operand precisionSlope = - tf.math.divNoNan(dTP, positive(dP)); + Operand precisionSlope = tf.math.divNoNan(dTP, positive(dP)); - Operand intercept = - tf.math.sub(tp1, tf.math.mul(precisionSlope, p1)); + Operand intercept = tf.math.sub(tp1, tf.math.mul(precisionSlope, p1)); Operand safePRatio = tf.select( - tf.math.logicalAnd( - isPositive(p0), isPositive(p1)), - tf.math.divNoNan( - p0, - positive(p1)), + tf.math.logicalAnd(isPositive(p0), isPositive(p1)), + tf.math.divNoNan(p0, positive(p1)), tf.onesLike(p1)); Operand fn1 = slice(falseNegatives, 1, -1); @@ -890,15 +896,12 @@ private Operand interpolatePRAuc() { tf.math.mul( precisionSlope, tf.math.add(dTP, tf.math.mul(intercept, tf.math.log(safePRatio)))); - Operand prAucIncrement = - tf.math.divNoNan( - aucTotalPos, - positive(tf.math.add(tp1, fn1))); + Operand prAucIncrement = tf.math.divNoNan(aucTotalPos, positive(tf.math.add(tp1, fn1))); if (isMultiLabel()) { Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0)); if (getLabelWeights() == null) { - 
//Evenly weighted average of the label AUCs. + // Evenly weighted average of the label AUCs. return MetricsHelper.mean(tf, byLabelAuc); } else { // Weighted average of the label AUCs. @@ -943,7 +946,6 @@ public Operand result() { // y[1:] Operand ySlice2 = slice(y, 1, -1); - Operand heights; switch (getSummationMethod()) { case INTERPOLATION: @@ -956,16 +958,13 @@ public Operand result() { heights = tf.math.maximum(ySlice1, ySlice2); break; default: - throw new IllegalArgumentException("Unexpected AUCSummationMethod value: " + getSummationMethod()); + throw new IllegalArgumentException( + "Unexpected AUCSummationMethod value: " + getSummationMethod()); } if (isMultiLabel()) { Operand riemannTerms = - tf.math.mul( - tf.math.sub( - slice(x, 0, getNumThresholds() - 1), - slice(x, 1, -1)), - heights); + tf.math.mul(tf.math.sub(slice(x, 0, getNumThresholds() - 1), slice(x, 1, -1)), heights); Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0)); if (getLabelWeights() == null) { @@ -979,7 +978,7 @@ public Operand result() { } } else { - Operand slice1 = slice(x,0, getNumThresholds() - 1); + Operand slice1 = slice(x, 0, getNumThresholds() - 1); Operand slice2 = slice(x, 1, -1); Operand sub = tf.math.sub(slice1, slice2); Operand operand = tf.math.mul(sub, heights); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java index 3887f687eea..735d97ecf09 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java @@ -21,8 +21,10 @@ * @see Riemann summation method */ public enum AUCSummationMethod { - /** Apply mid-point summation scheme for {@link AUCCurve#ROC}. 
For {@link AUCCurve#PR}, interpolates (true/false) positives but not the ratio that - * is precision */ + /** + * Apply mid-point summation scheme for {@link AUCCurve#ROC}. For {@link AUCCurve#PR}, + * interpolates (true/false) positives but not the ratio that is precision + */ INTERPOLATION, /** Apply right summation for increasing intervals and left summation for decreasing intervals */ MAJORING, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index 9548fb42c65..30787a9889b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.metrics.impl.LossMetric; import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; @@ -65,7 +66,15 @@ public Accuracy(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates how often predictions equals labels. {@code labels} and {@code predictions} must + * have compatible shapes, see {@link Shape @isCompatibleWith}. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @throws IllegalArgumentException if predictions and labels shapes are not compatible. 
+ * @return the loss + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index d2a414fdeb7..4f9a267d633 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -85,7 +85,13 @@ public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates how often predictions match binary labels. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Binary accuracy values. shape = {@code [batch_size, d0, .. dN-1]} + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 48ee244eafb..57a6f75375d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -60,7 +60,14 @@ public BinaryCrossentropy( this.labelSmoothing = labelSmoothing; } - /** {@inheritDoc} */ + /** + * Computes the binary crossentropy loss between labels and predictions. + * + * @param labels the truth values or labels, has the same shape as predictions and shape = {@code + * [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Binary crossentropy loss value. shape = {@code [batch_size, d0, .. dN-1]}. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index c3780cc6de2..55c3dc800e1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,8 +27,8 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

You can provide logits of classes as predictions, since argmax - * of logits and probabilities are same. + *

You can provide logits of classes as predictions, since argmax of + * logits and probabilities are same. * *

This metric creates two local variables, total and count that are * used to compute the frequency with which predictions matches labels. @@ -73,7 +73,17 @@ public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { super.setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the categorical crossentropy loss. + * + *

{@code predictions} and {@code labels} should be passed in as vectors of probabilities, + * rather than as labels. If necessary, use {@line Ops#oneHot} to expand {@code labels} as a + * vector. + * + * @param labels One-hot ground truth values. + * @param predictions tThe prediction values. + * @return Categorical accuracy values. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index b22e5415f79..a7e85ce5b02 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -99,7 +99,13 @@ public CategoricalCrossentropy( this.axis = axis; } - /** {@inheritDoc} */ + /** + * Computes the crossentropy loss between the labels and predictions. + * + * @param labels the truth values or labels, of one-hot true targets, same shape as predictions + * @param predictions the predictions + * @return Categorical crossentropy loss value. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 4266cc487c0..1f6d0fd002c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -45,7 +45,13 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the categorical hinge metric between {@code labels} and @{code predictions}. 
+ * + * @param labels the truth values or labels, labels values are expected to be 0 or 1. + * @param predictions the predictions + * @return Categorical hinge loss values. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 840f255c5ab..230286a738f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -26,7 +26,17 @@ /** * A metric that computes the cosine similarity metric between labels and predictions. * + *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 + * indicates orthogonality and values closer to -1 indicate greater similarity. The values closer to + * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + * try to maximize the proximity between predictions and targets. If either labels and predictions + * is a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and + * targets. + * + *

{@code loss = -sum(l2_norm(y_true) * l2_norm(y_pred))}
+ * * @param The data type for the metric result. + * @see Cosine Similarity */ public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { @@ -76,7 +86,13 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the cosine similarity loss between labels and predictions. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @return the cosine similarity loss + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 46ccd2859ff..a2d110867b8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -44,7 +44,13 @@ public Hinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the hinge loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return the hinge loss between labels and predictions. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 9ffcd6189f1..155a891ccc2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -45,7 +45,13 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes Kullback-Leibler divergence metric between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return the loss with shape {@code [batch_size, d0, .. dN-1]} + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 59e24f57110..786847d4b32 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -45,7 +45,13 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates the Logarithm of the hyperbolic cosine of the prediction error. + * + * @param labels Ground truth values, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Logcosh error values, shape = {@code [batch_size, d0, .. dN-1]}. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 1cc6d0b6f99..b38d0a809e1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -45,7 +45,13 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean absolute error loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean absolute error values, shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 8c6720b58f6..22bcd0ab0eb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -45,7 +45,13 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean absolute percentage error loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean absolute percentage error values, shape = {@code [batch_size, d0, .. 
dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 3cd3fd7c0ee..70c2e6db8f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -116,7 +116,15 @@ public Assign getInitializer() { return initializer; } - /** {@inheritDoc} */ + /** + * Accumulates the confusion matrix statistics. + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either + * 0, or the same rank as labels, and must be broadcastable to labels. + * @return the Operands that updates totalConfusionMatrix variable + */ @Override public List updateStateList( Operand labels, @@ -130,12 +138,13 @@ public List updateStateList( Operand tPredictions = cast(getTF(), predictions, type); if (tPredictions.shape().numDimensions() > 1) { tPredictions = getTF().shape.flatten(tPredictions); - } + } Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null; if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { tSampleWeights = getTF().shape.flatten(tSampleWeights); - } + } + // Accumulate the prediction to current confusion matrix. 
Operand currentCM = MetricsHelper.confusionMatrix( getTF(), tLabels, tPredictions, getTF().constant(numClasses), tSampleWeights, type); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index b8cec2150b7..ac25183c0e5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -33,8 +33,8 @@ * ultimately returned as mean relative error: an idempotent operation that simply divides total by * count. * - *

If {@code sampleWeight} is null, weights default to 1. Use {@code sampleWeight} of - * 0 to mask values. + *

If {@code sampleWeight} is null, weights default to 1. Use {@code sampleWeight} + * of 0 to mask values. * * @param The data type for the metric result */ @@ -129,9 +129,9 @@ protected MeanRelativeError( * * @param labels The ground truth values. * @param predictions The predicted values. Must be the same shape as the normalizer. - * @param sampleWeights Optional weighting of each example. A null value defaults to 1. Can be an {@code Operand} - * whose rank is either 0, or the same rank as {@code labels}, and must be broadcastable to - * {@code labels}. + * @param sampleWeights Optional weighting of each example. A null value defaults to 1. Can be an + * {@code Operand} whose rank is either 0, or the same rank as {@code labels}, and must be + * broadcastable to {@code labels}. * @return a List of Operations to update the metric state */ @Override @@ -142,7 +142,7 @@ public List updateStateList( Operand tLabels = cast(getTF(), labels, getResultType()); Operand tPredictions = cast(getTF(), predictions, getResultType()); Operand tSampleWeights = - sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); tPredictions = tuple.getTarget(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 3c4c79d39ba..fd8be29875e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -26,6 +26,19 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * + *

The {@code MeanSquaredError} class creates two local variables, {@code total} and {@code + * count} that are used to compute the mean squared error. This average is weighted by {@code + * weights}, and it is ultimately returned as the mean squared error: an idempotent operation that + * simply divides {@code total} by {@code count}. + * + *

For estimation of the metric over a stream of data, the function creates an update operation + * that updates these variables. Internally, a squared error operation computes the element-wise + * square of the difference between {@code predictions} and {@code labels}. Then the update + * operation increments {@code total} with the reduced sum of the product of {@code weights} and the + * squared error, and it increments {@code count} with the reduced sum of {@code weights}. + * + *

If {@code weights} is null, weights default to 1. Use weights of 0 to mask values. + * * @param The data type for the metric result. */ public class MeanSquaredError extends MeanMetricWrapper @@ -45,7 +58,13 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean squared error between the labels and predictions. + * + * @param labels the truth values or labels. Must be the same shape as predictions. + * @param predictions the predictions + * @return Computes the mean squared error between the labels and predictions. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index d525bb76648..4728cbab12f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -45,7 +45,13 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean squared logarithmic error between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean squared logarithmic error values, shape = {@code [batch_size, d0, .. dN-1]}. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java index f01cb47b256..d88d7a4c1b4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -109,6 +109,8 @@ private boolean init(Shape shape) { * @param values Per-example value. Input values must always have the same shape for all * invocations of updateStateList. * @param sampleWeights Optional weighting of each example. Defaults to 1 if null. + * @throws IllegalArgumentException if the shape of {@code values} in each subsequent call is not + * the same shape as {@code values} set during the first call */ @Override public List updateStateList( @@ -117,6 +119,7 @@ public List updateStateList( Operand tValues = cast(tf, values, type); Operand tSampleWeights = sampleWeights == null ? null : cast(tf, sampleWeights, type); + // update the shape if it is the first call. boolean needsInitialization = init(values.shape()); if (!this.shape.equals(values.shape())) { @@ -128,7 +131,7 @@ public List updateStateList( Operand numValues = tf.onesLike(tValues); if (tSampleWeights != null) { - //Update dimensions of weights to match with values if possible. + // Update dimensions of weights to match with values if possible. 
LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights); tValues = tuple.getTarget(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 3d4c262491f..a33750ac3f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -92,12 +92,14 @@ public static Operand sparseTopKCatego Operand castPredictions = cast(tf, predictions, TFloat32.class); if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { if (predictionsRank > 2) { - //y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) - castPredictions = tf.reshape(castPredictions, + // y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) + castPredictions = + tf.reshape( + castPredictions, tf.constant(castPredictions.shape().takeLast(1).prepend(Shape.UNKNOWN_SIZE))); } if (labelsRank > 1) { - //y_true = array_ops.reshape(y_true, [-1]) + // y_true = array_ops.reshape(y_true, [-1]) tLabels = tf.reshape(tLabels, tf.constant(Shape.of(Shape.UNKNOWN_SIZE))); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 422fd4808ff..2e4bde8ec55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -44,7 +44,13 @@ public Poisson(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the Poisson loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. 
+ * @return Poisson loss value, shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index bd536f16b29..5784cf46385 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -22,9 +22,14 @@ import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -290,7 +295,16 @@ private void init() { } } - /** {@inheritDoc} */ + /** + * Accumulates true positive and false positive statistics. + * + * @param labels the ground truth values, with the same dimensions as predictions. Will + * be cast to {@link TBool}. + * @param predictions the predictions, each element must be in the range {@code [0, 1]}. + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state.
+ */ @Override @SuppressWarnings("unchecked") public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 299c649279f..4205f761e4b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -110,6 +110,7 @@ public PrecisionAtRecall( this.recall = recall; } + /** {@inheritDoc} */ @Override public Operand result() { Ops tf = getTF(); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index 54e9de0d9cf..ca5968d4f9d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -22,9 +22,14 @@ import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -321,7 +326,16 @@ public Op resetStates() { return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); } - /** {@inheritDoc} */ + /** + * Accumulates true positive and false negative statistics. + * + * @param labels the ground truth values, with the same dimensions as predictions. Will + * be cast to {@link TBool}. + * @param predictions the predictions, each element must be in the range {@code [0, 1]}. + * @param sampleWeights Optional weighting of each example. Defaults to 1.
Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ @Override @SuppressWarnings("unchecked") public List updateStateList( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 9b4401964d7..721b95487c7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -60,7 +60,15 @@ public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { super(tf, name, seed, type); } - /** {@inheritDoc} */ + /** + * Accumulates root mean squared error statistics. + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ @Override public List updateStateList( Operand labels, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 7bfa7fd6ee9..6dfdab48578 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -94,7 +94,13 @@ public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) super.setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates how often predictions match integer labels. + * + * @param labels Integer ground truth values.
+ * @param predictions the predictions + * @return Sparse categorical accuracy values. + */ @Override public Operand call( Operand labels, Operand predictions) { @@ -106,6 +112,7 @@ public Operand call( long predictionsRank = predShape.numDimensions(); long labelsRank = labelsShape.numDimensions(); + // If the shape of labels is (num_samples, 1), squeeze to (num_samples,) if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE && labelsShape.size((int) labelsRank - 1) == 1) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index 9949f0c6b60..04555d85b66 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -25,7 +25,10 @@ /** * A metric that computes the sparse categorical cross-entropy loss between true labels and - * predicted labels. \ + * predicted labels. + * + *

You can provide logits of classes as predictions, since argmax of logits and probabilities are + * same. * * @param The data type for the metric result. */ @@ -55,7 +58,13 @@ public SparseCategoricalCrossentropy( this.axis = axis; } - /** {@inheritDoc} */ + /** + * Computes the sparse categorical cross-entropy loss between the integer labels and predictions. + * + * @param labels Integer ground truth values. + * @param predictions the predictions + * @return the sparse categorical cross-entropy loss. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java index 0fd600b4a0f..29dc91298d3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -24,8 +24,9 @@ /** * Computes how often integer targets are in the top `K` predictions. + * * @param The data type for the metric result - * */ + */ public class SparseTopKCategoricalAccuracy extends MeanMetricWrapper implements LossMetric { public static final int DEFAULT_K = 5; @@ -33,7 +34,8 @@ public class SparseTopKCategoricalAccuracy extends MeanMetric private final int k; /** - * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top elements. + * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top + * elements. * * @param tf the TensorFlow Ops * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. @@ -61,7 +63,13 @@ public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Clas setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes how often integer targets are in the top {@code K} predictions.
+ * + * @param labels the truth values or labels + * @param predictions the predictions + * @return Sparse top K categorical accuracy value. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 19b3b1d0ac4..e2ff208b8f5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -44,7 +44,15 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the squared hinge loss between labels and predictions. + * + * @param labels The ground truth values. {@code labels} values are expected to be -1 or 1. If + * binary (0 or 1) labels are provided we will convert them to -1 or 1. shape = {@code + * [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Squared hinge loss values. shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index ad78e48bc34..9c8d6403a6b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -63,7 +63,13 @@ public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class t setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes how often targets are in the top {@code K} predictions. 
+ * + * @param labels the truth values or labels + * @param predictions the predictions + * @return Top K categorical accuracy value. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index 31e88b6bb31..63ea35df7f2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -114,6 +114,7 @@ public ConfusionMatrixConditionCount( init(); } + /** Initialize the metric */ private void init() { Shape variableShape = Shape.of(this.thresholds.length); @@ -134,7 +135,15 @@ public Assign getInitializer() { return initializer; } - /** {@inheritDoc} */ + /** + * Accumulates the metric statistics. + * + * @param labels The ground truth values. + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. 
+ */ @Override public List updateStateList( Operand labels, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index c06616a6324..f36aaa34d8f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -22,7 +22,6 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; - import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.OneHot; import org.tensorflow.op.core.Rank; @@ -30,7 +29,6 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; import org.tensorflow.op.nn.TopK; - import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -269,7 +267,8 @@ public static List assertShapes( Operand size = dict.get(s); if (size == null) { // save size for later checks - size = tf.shape.size( symbol.getOperand(), tf.constant(ll.get()), TInt64.class); + size = + tf.shape.size(symbol.getOperand(), tf.constant(ll.get()), TInt64.class); dict.put(s, size); } Op assertion = @@ -280,7 +279,7 @@ public static List assertShapes( symbol.getOperand(), tf.constant(ll.getAndIncrement()), TInt64.class), - size), + size), Collections.singletonList(tf.constant(message))); updateOperations.add(assertion); }); @@ -305,47 +304,48 @@ public static List assertShapes( * will repeat the same for every threshold. * *

For estimation of these metrics over a stream of data, the function creates an `update_op` - * operation that updates the given variables.

+ * operation that updates the given variables. * *

labels, predictions, and sampleWeight tensors are - * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. - * sampleWeight is then broadcast to the shape of predictions.

+ * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. + * sampleWeight is then broadcast to the shape of predictions. * * @param tf the TensorFlow Ops * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If multiLabel, then the - * variable shapes are (T, D), where T is the number of thresholds and D is the number of - * classes (after slicing by classIndex, if provided). - * If multiLabels, then the variable shapes are (T). + * corresponding variables to update as values. If multiLabel, then the variable + * shapes are (T, D), where T is the number of thresholds and D is the number of classes + * (after slicing by classIndex, if provided). If not multiLabel, then + * the variable shapes are (T). * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and * corresponding initializer Operands to for variablesToUpdate. * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra * dimension of size 1 that would be squeezed. * @param predictions the predictions shape (N, Cx, P1?) - * @param thresholds thresholds in the range [0, 1], or {@link #NEG_INF} is used - * when topK is set - * @param topK optional, indicates that only the top k predictions should be considered. - * Applied before possibly slicing by classIndex. - * @param classIndex optional, limits the prediction and labels to the specified class. - * This is an integer index into the first dimension of Cx. - * @param sampleWeight optional Tensor that is aligned with labels and predictions - * as explained above. Use weights of 0 to mask values. + * @param thresholds thresholds in the range [0, 1], or {@link #NEG_INF} is used when + * topK is set + * @param topK optional, indicates that only the top k predictions should be considered.
Applied + * before possibly slicing by classIndex. + * @param classIndex optional, limits the prediction and labels to the specified class. This is an + * integer index into the first dimension of Cx. + * @param sampleWeight optional Tensor that is aligned with labels and predictions as + * explained above. Use weights of 0 to mask values. * @param multiLabel indicates whether multidimensional prediction/labels should be treated as * multilabel responses, or flattened into a single label. When true, the values of * variablesToUpdate must have a second dimension equal to the number of labels and * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES - * without explicit multilabel handling (i.e. when the data is to be flattened). - * Must have shape (Dx), which is the same as (Cx) referenced above, except that if - * classIndex is provided, then the final dimension of Dx is 1. These weights - * will be broadcast across the 0th dimension (the examples dimension) of - * predictions. May be null. Must be null if multiLabel. + * without explicit multilabel handling (i.e. when the data is to be flattened). Must have + * shape (Dx), which is the same as (Cx) referenced above, except that if classIndex + * is provided, then the final dimension of Dx is 1. These weights will be broadcast + * across the 0th dimension (the examples dimension) of predictions. May be null. + * Must be null if multiLabel. * @param the data type for the variables * @throws IllegalArgumentException If predictions and labels have * mismatched shapes, or if sampleWeight is not null and its shape - * doesn't match predictions, or if multiLabel && labelWeights != null. + * doesn't match predictions, or if multiLabel && labelWeights != null + * .
* @return an op to update the given confusion matrix variables. */ @SuppressWarnings({"unchecked", "rawtypes"}) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 2a26967b9f2..3b54ad2e08d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -106,11 +106,11 @@ public Op resetStates() { } /** - * Updates the metric variables based on the inputs. At least one input arg required for - * values, an optional additional input for the sampleWeights + * Updates the metric variables based on the inputs. At least one input arg required for {@code + * values}, an optional additional input for the sampleWeights * * @param values the inputs to be passed to update state, this may not be null - * @param sampleWeights sample weights to be applied to values, may be null. + * @param sampleWeights sample weights to be applied to values, will default to 1 if null. * @return the result with a control dependency on update state Operands * @throws IllegalArgumentException if values is null */ @@ -129,13 +129,16 @@ public List updateStateList( if (sampleWeights != null) { tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + // Update dimensions of weights to match with values if possible. LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); tValues = tuple.getTarget(); tSampleWeights = tuple.getSampleWeights(); try { + // Broadcast weights if possible tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); } catch (IllegalArgumentException ex) { + // reduce values to same ndim as weight array // if we get here we have static shapes with either // different ranks or different dimension sizes.
// first, reduce the values down to the rank of the samples @@ -162,7 +165,9 @@ public List updateStateList( getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; + // Exit early if the reduction doesn't have a denominator. if (reduction != MetricReduction.SUM) { + // Update `count` for reductions that require a denominator. switch (reduction) { case SUM_OVER_BATCH_SIZE: numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); @@ -183,6 +188,7 @@ public List updateStateList( throw new UnsupportedOperationException( String.format("reduction [%s] not implemented", reduction)); } + Operand totalCount = getTF().assignAdd(this.count, numValues); updateOperations.add(totalCount); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 84898d8a4d3..08b298294ac 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -10,7 +10,11 @@ import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TNumber; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -111,6 +115,11 @@ private void init() { } } + /** + * Gets a control dependency Op to initialize all the variables + * + * @return a control dependency Op to initialize all the variables + */ public Op initializeVariables() { List varInitializers = new ArrayList<>(); @@ -130,7 +139,15 @@ public Op initializeVariables() { return getTF().withControlDependencies(varInitializers).noOp(); } - /** {@inheritDoc} */ 
+ /** + * Accumulates confusion matrix statistics. + * + * @param labels The ground truth values. + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ @Override @SuppressWarnings("unchecked") public List updateStateList( From 980bb641bc3c45522472bc9a5eb32d3c9452ae33 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 1 Apr 2021 19:45:26 -0400 Subject: [PATCH 86/97] Fixed Javadoc errors. Change all xxxxxx to {@code xxxxxx} --- .../org/tensorflow/framework/metrics/AUC.java | 102 +++++++------- .../framework/metrics/Accuracy.java | 4 +- .../framework/metrics/BinaryAccuracy.java | 4 +- .../metrics/CategoricalAccuracy.java | 14 +- .../metrics/CategoricalCrossentropy.java | 26 ++-- .../framework/metrics/FalseNegatives.java | 24 ++-- .../framework/metrics/FalsePositives.java | 24 ++-- .../tensorflow/framework/metrics/MeanIoU.java | 6 +- .../framework/metrics/MeanRelativeError.java | 6 +- .../framework/metrics/Precision.java | 36 ++--- .../framework/metrics/PrecisionAtRecall.java | 2 +- .../tensorflow/framework/metrics/Recall.java | 18 +-- .../framework/metrics/RecallAtPrecision.java | 2 +- .../metrics/RootMeanSquaredError.java | 2 +- .../metrics/SensitivityAtSpecificity.java | 14 +- .../metrics/SparseCategoricalAccuracy.java | 4 +- .../metrics/SpecificityAtSensitivity.java | 16 +-- .../org/tensorflow/framework/metrics/Sum.java | 6 +- .../metrics/TopKCategoricalAccuracy.java | 2 +- .../framework/metrics/TrueNegatives.java | 24 ++-- .../framework/metrics/TruePositives.java | 24 ++-- .../impl/ConfusionMatrixConditionCount.java | 8 +- .../framework/metrics/impl/LossMetric.java | 2 +- .../metrics/impl/MeanMetricWrapper.java | 4 +- .../framework/metrics/impl/MetricsHelper.java | 129 +++++++++--------- .../framework/metrics/impl/Reduce.java | 4 +- 
.../framework/metrics/impl/SetsOps.java | 42 +++--- .../metrics/impl/WeightsBroadcastOps.java | 10 +- 28 files changed, 279 insertions(+), 280 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java index 3dbc6f22cec..bc5047d5855 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -40,25 +40,25 @@ /** * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. * - *

This metric creates four local variables, truePositives, trueNegatives - * , falsePositives and falseNegatives that are used to compute the + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives + * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the * AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of * recall and precision values. The area under the ROC-curve is therefore computed using the height * of the recall values by the false positive rate, while the area under the PR-curve is the * computed using the height of the precision values by the recall. * - *

This value is ultimately returned as auc, an idempotent operation that computes + *

This value is ultimately returned as {@code auc}, an idempotent operation that computes * the area under a discretized curve of precision versus recall values (computed using the - * aforementioned variables). The numThresholds variable controls the degree of + * aforementioned variables). The {@code numThresholds} variable controls the degree of * discretization with larger numbers of thresholds more closely approximating the true AUC. The - * quality of the approximation may vary dramatically depending on numThresholds. The - * thresholds parameter can be used to manually specify thresholds which split the + * quality of the approximation may vary dramatically depending on {@code numThresholds}. The + * {@code thresholds} parameter can be used to manually specify thresholds which split the * predictions more evenly. * - *

For best results, predictions should be distributed approximately uniformly in + *

For best results, {@code predictions} should be distributed approximately uniformly in * the range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor - * if this is not the case. Setting summationMethod to minoring or - * majoring can help quantify the error in the approximation by providing lower or upper + * if this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code + * majoring} can help quantify the error in the approximation by providing lower or upper * bound estimate of the AUC. * *

Usage:
@@ -155,8 +155,8 @@ public class AUC extends Metric { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link #DEFAULT_NUM_THRESHOLDS} for the numThresholds, {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for - * thresholds, false for multiLabel, and null for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed @@ -180,11 +180,11 @@ public AUC(Ops tf, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NUM_THRESHOLDS} for the * numThresholds, {@link AUCCurve#ROC} for the curve type, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, null for thresholds, - * false for multiLabel, and null for labelWeights. + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the confusion matrix variables. 
@@ -206,8 +206,8 @@ public AUC(Ops tf, String name, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, null for thresholds, false for multiLabel, and - * null for labelWeights. + * summation method, {@code null} for thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -233,8 +233,8 @@ public AUC(Ops tf, int numThresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the - * summation method, null for numThresholds, false for multiLabel, and - * null for labelWeights. + * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -259,11 +259,11 @@ public AUC(Ops tf, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric. using {@link AUCCurve#ROC} for the curve type, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for - * thresholds, false for multiLabel, and null for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. 
* * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param seed the seed for random number generation. An initializer created with a given seed @@ -285,13 +285,13 @@ public AUC(Ops tf, String name, int numThresholds, long seed, Class type) { } /** - * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link * AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the summation - * method, {@link #DEFAULT_NUM_THRESHOLDS} num thresholds, false for multiLabel, and - * null for labelWeights. + * method, {@link #DEFAULT_NUM_THRESHOLDS} num thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. * * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. * @param seed the seed for random number generation. An initializer created with a given seed @@ -314,11 +314,11 @@ public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { /** * Creates an AUC (Area under the curve) metric using {@link AUCSummationMethod#INTERPOLATION} for - * the summation method, null for thresholds, false for multiLabel, and - * null for labelWeights. + * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. 
* * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -342,12 +342,12 @@ public AUC(Ops tf, String name, int numThresholds, AUCCurve curve, long seed, Cl } /** - * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link * AUCSummationMethod#INTERPOLATION} for the summation method, {@link #DEFAULT_NUM_THRESHOLDS} num - * thresholds, false for multiLabel, and null for labelWeights. + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -372,8 +372,8 @@ public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, C /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * {@link AUCSummationMethod#INTERPOLATION} for the summation method, null for - * thresholds, false for multiLabel, and null for labelWeights. + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. 
* * @param tf The TensorFlow Ops * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values @@ -399,9 +399,9 @@ public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) } /** - * Creates an AUC (Area under the curve) metric using null for numThresholds, {@link - * AUCSummationMethod#INTERPOLATION} for the summation method, false for multiLabel, - * and null for labelWeights. + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, + * and {@code null} for labelWeights. * * @param tf The TensorFlow Ops * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, @@ -428,7 +428,7 @@ public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) /** * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, - * null for thresholds, false for multiLabel, and null for + * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for * labelWeights. * * @param tf The TensorFlow Ops @@ -453,7 +453,7 @@ public AUC( /** * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, - * null for numThresholds, false for multiLabel, and null + * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} * for labelWeights. * * @param tf The TensorFlow Ops @@ -487,11 +487,11 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using null for thresholds, - * false for multiLabel, and null for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. 
* * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values * must be > 1. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -513,11 +513,11 @@ public AUC( } /** - * Creates an AUC (Area under the curve) metric. using null for the numThresholds, - * false for multiLabel, and null for labelWeights. + * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. * * @param tf The TensorFlow Ops - * @param name the name of the metric, if null defaults to {@link #DEFAULT_NAME} + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link @@ -560,15 +560,15 @@ public AUC( * @param summationMethod Specifies the Riemann summation method used * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, * the numThresholds parameter is ignored. Values should be in [0, 1]. This method - * automatically brackets the provided thresholds with a (-{@link #EPSILON}) + * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) * below and a (1 + {@link #EPSILON}) above. * @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein * AUC is computed separately for each label and then averaged across labels, or (when false) * if the data should be flattened into a single label before AUC computation. 
In the latter * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an - * individual data point. Should be set to false for multi-class data. - * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When - * multiLabel is true, the weights are applied to the individual label AUCs when they + * individual data point. Should be set to {@code false} for multi-class data. + * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When {@code + * multiLabel} is true, the weights are applied to the individual label AUCs when they * are averaged to produce the multi-label AUC. When it's false, they are used to weight the * individual label predictions in computing the confusion matrix on the flattened data. * @param seed the seed for random number generation. An initializer created with a given seed @@ -715,9 +715,9 @@ private Map> build(Shape shape) { * * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be - * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} != null - * , then Cx must be a single dimension. - * @param predictions the predictions shape (N, Cx, P1?). Will be cast to T. + * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} {@code != null + * }, then Cx must be a single dimension. + * @param predictions the predictions shape (N, Cx, P1?). Will be cast to {@code T}. * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to * {@code }. 
* @return a List of Operations to update the metric state diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java index 30787a9889b..516d6c91ba6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -29,11 +29,11 @@ * Metric that calculates how often predictions equals labels. * *

This metric creates two local variables, total and count that are used to compute the - * frequency with which predictions matches labels. This frequency is + * frequency with which {@code predictions} matches {@code labels}. This frequency is * ultimately returned as binary accuracy: an idempotent operation that simply divides total by * count. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask + *

If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask * values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java index 4f9a267d633..0e41699e165 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -26,11 +26,11 @@ * Metric that calculates how often predictions matches binary labels. * *

This metric creates two local variables, total and count that are used to compute the - * frequency with which predictions matches labels. This frequency is + * frequency with which {@code predictions} matches {@code labels}. This frequency is * ultimately returned as binary accuracy: an idempotent operation that simply divides total by * count. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of 0 to mask + *

If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask * values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java index 55c3dc800e1..dece2d1cd50 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -27,18 +27,18 @@ /** * Metric that calculates how often predictions matches one-hot labels. * - *

You can provide logits of classes as predictions, since argmax of - * logits and probabilities are same. + *

You can provide {@code logits} of classes as {@code predictions}, since argmax of + * {@code logits} and probabilities are same. * - *

This metric creates two local variables, total and count that are - * used to compute the frequency with which predictions matches labels. + *

This metric creates two local variables, {@code total} and {@code count} that are + * used to compute the frequency with which {@code predictions} matches {@code labels}. * This frequency is ultimately returned as categorical accuracy: an idempotent operation that * simply divides total by count. * - *

predictions and labels should be passed in as vectors of + *

{@code predictions} and {@code labels} should be passed in as vectors of * probabilities, rather than as labels. If necessary, use {@link * org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand - * labels as a vector. + * {@code labels} as a vector. * *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. * @@ -77,7 +77,7 @@ public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { * Computes the categorical crossentropy loss. * *

{@code predictions} and {@code labels} should be passed in as vectors of probabilities, - * rather than as labels. If necessary, use {@line Ops#oneHot} to expand {@code labels} as a + * rather than as labels. If necessary, use {@link Ops#oneHot} to expand {@code labels} as a * vector. * * @param labels One-hot ground truth values. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index a7e85ce5b02..58aa51f664c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -28,9 +28,9 @@ * labels. * *

This is the crossentropy metric class to be used when there are multiple label classes (2 or - * more). The labels should be given as a one_hot representation. eg., When labels values are - * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] - * . + * more). The labels should be given as a one_hot representation. eg., When labels values are {@code + * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + * }. * * @param The data type for the metric result */ @@ -52,9 +52,9 @@ public class CategoricalCrossentropy extends MeanMetricWrappe * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 - * means that we will use a value of 0.1 for label 0 and 0.9 - * for label 1 + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} + * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 + * } for label {@code 1} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result @@ -73,13 +73,13 @@ public CategoricalCrossentropy( * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 - * means that we will use a value of 0.1 for label 0 and 0.9 - * for label 1 - * @param axis Int specifying the channels axis. 
axis={@link Losses#CHANNELS_LAST} - * corresponds to data format channels_last, and - * axis={@link Losses#CHANNELS_FIRST} corresponds to data format - * channels_first. + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} + * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 + * } for label {@code 1} + * @param axis Int specifying the channels axis. {@code axis={@link Losses#CHANNELS_LAST}} + * corresponds to data format {@code channels_last}, and {@code + * axis={@link Losses#CHANNELS_FIRST}} corresponds to data format {@code + * channels_first}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java index 39d33dda665..3db7fffc2e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false negatives. * - *

If sampleWeights is given, calculates the sum of the weights of false negatives. - * This metric creates one local variable, accumulator that is used to keep track of + *

If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of * the number of false negatives. * - *

If sampleWeights is null, weights default to 1. Use - * sampleWeights of 0 to mask values. + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. * * @param The data type for the metric result */ @@ -50,9 +50,9 @@ public FalseNegatives(Ops tf, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -66,9 +66,9 @@ public FalseNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
@@ -96,9 +96,9 @@ public FalseNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -113,9 +113,9 @@ public FalseNegatives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java index 3cf9fc0a5e9..551529b6179 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of false positives. * - *

If sampleWeights is given, calculates the sum of the weights of false positives. - * This metric creates one local variable, accumulator that is used to keep track of + *

If {@code sampleWeights} is given, calculates the sum of the weights of false positives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of * the number of false positives. * - *

If sampleWeights is null, weights default to 1. Use - * sampleWeights of 0 to mask values. + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. * * @param The data type for the metric result */ @@ -50,9 +50,9 @@ public FalsePositives(Ops tf, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -66,9 +66,9 @@ public FalsePositives(Ops tf, float threshold, long seed, Class type) { * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
@@ -96,9 +96,9 @@ public FalsePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -113,9 +113,9 @@ public FalsePositives(Ops tf, String name, float threshold, long seed, Class * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 70c2e6db8f6..03c31b2bab8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -35,11 +35,11 @@ * *

Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, * which first computes the IOU for each semantic class and then computes the average over classes. - * IOU is defined as follows: IOU = true_positive - * / (true_positive + false_positive + false_negative). The predictions are accumulated in a + * IOU is defined as follows: {@code IOU = true_positive + * / (true_positive + false_positive + false_negative)}. The predictions are accumulated in a * confusion matrix, weighted by sample_weight and the metric is then calculated from it. * - *

If sampleWeight is null, weights default to 1. Use sample_weight of + *

If {@code sampleWeight} is {@code null}, weights default to 1. Use sample_weight of * 0 to mask values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java index ac25183c0e5..acf28f5b2cc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -28,12 +28,12 @@ /** * Computes the mean relative error by normalizing with the given values. * - *

This metric creates two local variables, total and count that are - * used to compute the mean relative error. This is weighted by sampleWeight, and it is + *

This metric creates two local variables, {@code total} and {@code count} that are + * used to compute the mean relative error. This is weighted by {@code sampleWeight}, and it is * ultimately returned as mean relative error: an idempotent operation that simply divides total by * count. * - *

If {@code sampleWeight} is null, weights default to 1. Use {@code sampleWeight} + *

If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} * of 0 to mask values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index 5784cf46385..c56c53addf0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -36,22 +36,22 @@ /** * Computes the precision of the predictions with respect to the labels. * - *

The metric creates two local variables, truePositives and falsePositives - * that are used to compute the precision. This value is ultimately returned as precision, - * an idempotent operation that simply divides truePositives by the sum of - * truePositives and falsePositives. + *

The metric creates two local variables, {@code truePositives} and {@code falsePositives + * } that are used to compute the precision. This value is ultimately returned as precision, + * an idempotent operation that simply divides {@code truePositives} by the sum of {@code + * truePositives} and {@code falsePositives}. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of * 0 to mask values. * - *

If topK is set, the metric calculates precision as how often on average a class + *

If {@code topK} is set, the metric calculates precision as how often on average a class * among the top-k classes with the highest predicted values of a batch entry is correct and can be * found in the label for that entry. * - *

If classId is specified, the metric calculates precision by considering only the - * entries in the batch for which classId is above the thresholds and/or - * in the top-k highest predictions, and computing the fraction of them for which classId - * is indeed a correct label. + *

If {@code classId} is specified, the metric calculates precision by considering only the + * entries in the batch for which {@code classId} is above the {@code thresholds} and/or + * in the top-k highest predictions, and computing the fraction of them for which {@code classId + * } is indeed a correct label. * * @param The data type for the metric result */ @@ -103,7 +103,7 @@ public Precision(Ops tf, String name, long seed, Class type) { * values. * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -120,7 +120,7 @@ public Precision(Ops tf, float threshold, long seed, Class type) { * values. * * @param tf the TensorFlow Ops - * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -138,7 +138,7 @@ public Precision(Ops tf, float[] thresholds, long seed, Class type) { * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. 
@@ -156,7 +156,7 @@ public Precision(Ops tf, String name, float threshold, long seed, Class type) * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -172,7 +172,7 @@ public Precision(Ops tf, String name, float[] thresholds, long seed, Class ty * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param threshold Optional threshold value in the range [0, 1]. A threshold is + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -193,7 +193,7 @@ public Precision( * Creates a Precision Metric with a name of {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -216,7 +216,7 @@ public Precision( * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param threshold Optional threshold value in the range [0, 1]. 
A threshold is + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. @@ -245,7 +245,7 @@ public Precision( * @param tf the TensorFlow Ops * @param name name of the metric instance. If null, name defaults to {@link * Class#getSimpleName()}. - * @param thresholds Optional threshold values in the range [0, 1]. A threshold is + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is * compared with prediction values to determine the truth value of predictions (i.e., above * the threshold is true, below is false). One metric value is generated for each threshold * value. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 4205f761e4b..483b2523d74 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -29,7 +29,7 @@ * falseNegatives that are used to compute the precision at the given recall. The threshold for the * given recall value is computed and used to evaluate the corresponding precision. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of + *

If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of * 0 to mask values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java index ca5968d4f9d..3886ec050b0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -36,20 +36,20 @@ /** * Computes the recall of the predictions with respect to the labels. * - *

This metric creates two local variables, truePositives and falseNegatives - * , that are used to compute the recall. This value is ultimately returned as recall, an - * idempotent operation that simply divides truePositives by the sum of - * truePositives and falseNegatives. + *

This metric creates two local variables, {@code truePositives} and + * {@code falseNegatives}, that are used to compute the recall. This value is ultimately returned as recall, an + * idempotent operation that simply divides {@code truePositives} by the sum of + * {@code truePositives} and {@code falseNegatives}. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code sampleWeights} of + * 0 to mask values. * - *

If topK is set, the metric calculates recall as how often on average a class + *

If {@code topK} is set, the metric calculates recall as how often on average a class * among the labels of a batch entry is in the top-k predictions. * - *

If classId is specified, the metric calculates recall by considering only the - * entries in the batch for which classId is in the label, and computing the fraction - * of them for which classId is above the threshold and/or in the top-k predictions. + *

If {@code classId} is specified, the metric calculates recall by considering only the + * entries in the batch for which {@code classId} is in the label, and computing the fraction + * of them for which {@code classId} is above the threshold and/or in the top-k predictions. * * @param The data type for the metric result */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index fb6890d1e01..72eaedb9c4d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -34,7 +34,7 @@ * falseNegatives that are used to compute the recall at the given precision. The threshold for the * given precision value is computed and used to evaluate the corresponding recall. * - *

If sampleWeights is null, weights default to 1. Use sampleWeights of + *

If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of * 0 to mask values. * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java index 721b95487c7..3886428425b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -27,7 +27,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes root mean squared error metric between labels and predictions + * Computes root mean squared error metric between {@code labels} and {@code predictions} * . * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 2c7420a5518..7cf5f38d9a4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -25,18 +25,18 @@ /** * Computes best sensitivity where sensitivity is >= specified value. * - *

Sensitivity measures the proportion of actual positives that are correctly - * identified as such (tp / (tp + fn)). + *

{@code Sensitivity} measures the proportion of actual positives that are correctly + * identified as such {@code (tp / (tp + fn))}. * - *

Specificity measures the proportion of actual negatives that are correctly - * identified as such (tn / (tn + fp)). + *

{@code Specificity} measures the proportion of actual negatives that are correctly + * identified as such {@code (tn / (tn + fp))}. * - *

This metric creates four local variables, truePositives, trueNegatives - * , falsePositives and falseNegatives that are used to compute the + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives}, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the * sensitivity at the given specificity. The threshold for the given specificity value is computed * and used to evaluate the corresponding sensitivity. * - *

If sampleWeights is null, weights default to 1. Use sample_weight of + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of * 0 to mask values. * * @see Additional information diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java index 6dfdab48578..5294f798044 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -31,11 +31,11 @@ /** * Calculates how often predictions matches integer labels. * - *

You can provide logits of classes as predictions, since argmax of logits and + *

You can provide logits of classes as {@code predictions}, since argmax of logits and * probabilities are same. * *

This metric creates two local variables, `total` and `count` that are used to compute the - * frequency with which predictions matches labels. This frequency is + * frequency with which {@code predictions} matches {@code labels}. This frequency is * ultimately returned as `sparse categorical accuracy`: an idempotent operation that simply divides * `total` by `count`. * diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index d0b797690bd..981171f2221 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -23,19 +23,19 @@ import static org.tensorflow.framework.utils.CastHelper.cast; /** - * Computes best specificity where sensitivity is >= specified value. Sensitivity - * measures the proportion of actual positives that are correctly identified as such - * (tp / (tp + fn)). + * Computes best specificity where sensitivity is >= specified value. {@code Sensitivity} + * measures the proportion of actual positives that are correctly identified as such {@code + * (tp / (tp + fn))}. * - *

Specificity measures the proportion of actual negatives that are correctly - * identified as such (tn / (tn + fp)). + *

{@code Specificity} measures the proportion of actual negatives that are correctly + * identified as such {@code (tn / (tn + fp))}. * - *

This metric creates four local variables, truePositives, trueNegatives - * , falsePositives and falseNegatives that are used to compute the + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives}, + * {@code falsePositives} and {@code falseNegatives} that are used to compute the * specificity at the given sensitivity. The threshold for the given sensitivity value is computed * and used to evaluate the corresponding specificity. * - *

If sampleWeights is null, weights default to 1. Use sample_weight of + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of * 0 to mask values. * * @see Additional information diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java index a3241221b66..637ca6cdd05 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -21,10 +21,10 @@ /** * Computes the (weighted) sum of the given values. * - *

For example, if values is [1, 3, 5, 7] then the sum is 16. If the - * weights were specified as [1, 1, 0, 0], then the sum would be 4. + *

For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the + * weights were specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} * - *

This metric creates one variable, total, that is used to compute the sum of + *

This metric creates one variable, {@code total}, that is used to compute the sum of * values. This is ultimately returned as sum. * *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java index 9c8d6403a6b..0146552433f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -34,7 +34,7 @@ public class TopKCategoricalAccuracy extends MeanMetricWrappe private final int k; /** - * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for k, Number of + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of * top elements to look at for computing accuracy. * * @param tf the TensorFlow Ops diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java index 91b6751588a..5c65f8c469f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true negatives. * - *

If sampleWeights is given, calculates the sum of the weights of true negatives. - * This metric creates one local variable, accumulator that is used to keep track of + *

If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of * the number of true negatives. * - *

If sampleWeights is null, weights default to 1. Use - * sampleWeights of 0 to mask values. + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. * * @param The data type for the metric result */ @@ -50,9 +50,9 @@ public TrueNegatives(Ops tf, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -66,9 +66,9 @@ public TrueNegatives(Ops tf, float threshold, long seed, Class type) { * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
@@ -96,9 +96,9 @@ public TrueNegatives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -113,9 +113,9 @@ public TrueNegatives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java index b67d381a62d..f0dd8c42de5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -22,12 +22,12 @@ /** * Metric that calculates the number of true positives. * - *

If sampleWeights is given, calculates the sum of the weights of true positives. - * This metric creates one local variable, accumulator that is used to keep track of + *

If {@code sampleWeights} is given, calculates the sum of the weights of true positives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of * the number of true positives. * - *

If sampleWeights is null, weights default to 1. Use - * sampleWeights of 0 to mask values. + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. * * @param The data type for the metric result */ @@ -50,9 +50,9 @@ public TruePositives(Ops tf, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -66,9 +66,9 @@ public TruePositives(Ops tf, float threshold, long seed, Class type) { * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name * * @param tf the TensorFlow Ops - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
@@ -96,9 +96,9 @@ public TruePositives(Ops tf, String name, long seed, Class type) { * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param threshold a threshold value in the range [0, 1]. A threshold is compared + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -113,9 +113,9 @@ public TruePositives(Ops tf, String name, float threshold, long seed, Class t * * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used - * @param thresholds threshold values in the range [0, 1]. A threshold is compared + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared * with prediction values to determine the truth value of predictions (i.e., above the - * threshold is true, below is false). One metric value is generated + * threshold is {@code true}, below is {@code false}). One metric value is generated * for each threshold value * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java index 63ea35df7f2..88597cf85ec 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -67,9 +67,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param threshold a threshold value in [0, 1]. A threshold is compared with + * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is - * true, below is false). One metric value is generated for each + * {@code true}, below is {@code false}). One metric value is generated for each * threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -91,9 +91,9 @@ public ConfusionMatrixConditionCount( * @param tf the TensorFlow Ops * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used * @param confusionMatrixCond the confusion matrix condition to calculate - * @param thresholds threshold values in [0, 1]. A threshold is compared with + * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with * prediction values to determine the truth value of predictions (i.e., above the threshold is - * true, below is false). One metric value is generated for each + * {@code true}, below is {@code false}). 
One metric value is generated for each * threshold value. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 1fb3d3bb580..f89047e457d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -25,7 +25,7 @@ public interface LossMetric { /** - * Calculates the weighted loss between labels and predictions + * Calculates the weighted loss between {@code labels} and {@code predictions} * * @param labels the truth values or labels * @param predictions the predictions diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 9a532a0294f..37bdd5849ae 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -29,8 +29,8 @@ * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. * - *

The loss function calculates the loss between the labels and predictions - * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the + *

The loss function calculates the loss between the {@code labels} and {@code predictions + * } then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index f36aaa34d8f..54b2646a62b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -58,8 +58,8 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the sampleWeights can be broadcast to the same shape as values - * + * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values + * } * *

In losses and metrics, limited weight broadcasting is supported. Weights must be either * scalar, or the same rank as the target values, with each dimension either 1, or the same as the @@ -68,11 +68,11 @@ public class MetricsHelper { * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return Operation with control dependencies to ensure sampleWeight - * can be broadcast to values + * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} + * can be broadcast to {@code values} * @param the type of Operand - * @throws NotBroadcastableException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to values + * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an + * incorrect shape that prohibit broadcasting to {@code values} */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -200,13 +200,13 @@ private static Operand canBroadcastDims( } /** - * Broadcast weights to the same shape as values. + * Broadcast {@code weights} to the same shape as {@code values}. * * @param tf the TensorFlow ops - * @param weights Operand whose shape is broadcastable to values. + * @param weights Operand whose shape is broadcastable to {@code values}. * @param values Operand of any shape * @param the type of Operands - * @return weights broadcast to values shape + * @return {@code weights} broadcast to {@code values} shape */ public static Operand broadcastWeights( Ops tf, Operand weights, Operand values) { @@ -291,13 +291,13 @@ public static List assertShapes( /** * Returns an op to update the given confusion matrix variables. * - *

For every pair of values in labels and predictions: + *

For every pair of values in {@code labels} and {@code predictions}: * *

-   * TRUE_POSITIVES:  labels == true and predictions > thresholds
-   * FALSE_POSITIVES: labels == true and predictions <= thresholds
-   * TRUE_NEGATIVES:  labels == false and predictions <= thresholds
-   * FALSE_NEGATIVE:  labels == false and predictions > thresholds
+   * TRUE_POSITIVES:  {@code labels} == true and {@code predictions} > thresholds
+   * FALSE_POSITIVES: {@code labels} == true and {@code predictions} <= thresholds
+   * TRUE_NEGATIVES:  {@code labels} == false and {@code predictions} <= thresholds
+   * FALSE_NEGATIVE:  {@code labels} == false and {@code predictions} > thresholds
    * 
* *

The results will be weighted and added together. When multiple thresholds are provided, we @@ -306,46 +306,45 @@ public static List assertShapes( *

For estimation of these metrics over a stream of data, the function creates an `update_op` * operation that updates the given variables. * - *

labels, predictions, and sampleWeight tensors are - * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. - * sampleWeight is then broadcast to the shape of predictions. + *

{@code labels}, {@code predictions}, and {@code sampleWeight} tensors are + * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code + * sampleWeight} is then broadcast to the shape of {@code predictions}. * * @param tf the TensorFlow Ops * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If multiLabel, then the variable + * corresponding variables to update as values. If {@code multiLabel}, then the variable * shapes are (T, D), where T is the number of thresholds and D is the number of classes - * (after slicing by classIndex, if provided). If multiLabels, then + * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then * the variable shapes are (T). * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding initializer Operands to for variablesToUpdate. + * corresponding initializer Operands to for {@code variablesToUpdate}. * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra * dimension of size 1 that would be squeezed. * @param predictions the predictions shape (N, Cx, P1?) - * @param thresholds thresholds in the range [0, 1], or {@link #NEG_INF} is used when + * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when * topK is set * @param topK optional, indicates that only the top k predictions should be considered. Applied - * before possibly slicing by classIndex. + * before possibly slicing by {@code classIndex}. * @param classIndex optional, limits the prediction and labels to the specified class. This is an * integer index into the first dimension of Cx. 
- * @param sampleWeight optional Tensor that is aligned with labels and predictions as + * @param sampleWeight optional {@code Tensor} that is aligned with labels and predictions as * explained above. Use weights of 0 to mask values. * @param multiLabel indicates whether multidimensional prediction/labels should be treated as - * multilabel responses, or flattened into a single label. When true, the values of - * variablesToUpdate must have a second dimension equal to the number of labels and + * multilabel responses, or flattened into a single label. When true, the values of {@code + * variablesToUpdate} must have a second dimension equal to the number of labels and * predictions per example, and those tensors must not be RaggedTensors. * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES * without explicit multilabel handling (i.e. when the data is to be flattened). Must have - * shape (Dx), which is the same as (Cx) referenced above, except that if classIndex - * is provided, then the final dimension of Dx is 1. These weights will be broadcast - * across the 0th dimension (the examples dimension) of predictions. May be null. - * Must be null if multiLabel. + * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex + * } is provided, then the final dimension of Dx is 1. These weights will be broadcast + * across the 0th dimension (the examples dimension) of {@code predictions}. May be null. + * Must be null if {@code multiLabel}. * @param the data type for the variables - * @throws IllegalArgumentException If predictions and labels have - * mismatched shapes, or if sampleWeight is not nulland its shape - * doesn't match predictions, or if multiLabel && labelWeights != null - * . 
+ * @throws IllegalArgumentException If {@code predictions} and {@code labels} have + * mismatched shapes, or if {@code sampleWeight} is not null and its shape + * doesn't match {@code predictions}, or if {@code multiLabel && labelWeights != null}.. * @return an op to update the given confusion matrix variables. */ @SuppressWarnings({"unchecked", "rawtypes"}) @@ -689,8 +688,8 @@ private static Operand filterTopK(Ops tf, Operand x, i // alias for mean /** - * Calculate the mean of the operand, along all axes and keepDims is false - * + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false + * } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -702,8 +701,8 @@ public static Operand mean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is + * {@code false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -721,12 +720,12 @@ public static Operand mean( * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. * @param the type of the operand - * @return the mean of elements of x. + * @return the mean of elements of {@code x}. 
*/ public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); @@ -738,12 +737,12 @@ public static Operand mean(Ops tf, Operand x, boolean * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. * @param the data type of the Operand - * @return the mean of elements of x. + * @return the mean of elements of {@code x}. */ public static Operand mean( Ops tf, Operand x, Operand axes, boolean keepDims) { @@ -778,16 +777,16 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * *

For example: * - *

+   * 
{@code
    *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
    *          [[0 0 0 0 0]
    *           [0 0 1 0 0]
    *           [0 0 1 0 0]
    *           [0 0 0 0 0]
    *           [0 0 0 0 1]]
-   * 
+ * }
* - * Note that the possible labels are assumed to be {@copde [0, 1, 2, 3,4]}, resulting in a 5x5 + * Note that the possible labels are assumed to be {@code [0, 1, 2, 3,4]}, resulting in a 5x5 * confusion matrix. * * @param tf the TensorFlow Ops @@ -798,12 +797,12 @@ LossTuple raggedAssertCompatibleAndGetFlatValues( * @param weights optional weights to be applied to the confusion matrix * @param type Data type of the confusion matrix. * @param the type of Operands - * @return A Operand of type type with shape [n, n] - * representing the confusion matrix, where n is the number of possible labels in + * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} + * representing the confusion matrix, where {@code n} is the number of possible labels in * the classification task. - * @throws IllegalArgumentException If both predictions and labels do - * not have compatible shapes, or if weights is notnull and its - * shape is not compatible with predictions. + * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do + * not have compatible shapes, or if {@code weights} is not{@code null} and its + * shape is not compatible with {@code predictions}. */ // TODO should this be moved to FramnworkOps under math. 
public static Operand confusionMatrix( @@ -879,8 +878,8 @@ public static Operand confusionMatrix( } /** - * Calculate the mean of the operand, along all axes and keepDims is false - * + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false + * } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -891,8 +890,8 @@ public static Operand booleanMean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is + * {@code false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -909,11 +908,11 @@ public static Operand booleanMean( * * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. - * @return the mean of elements of x containing floating point numbers + * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); @@ -925,11 +924,11 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. 
If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. - * @return the mean of elements of x containing floating point numbers + * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean( Ops tf, Operand x, Operand axes, boolean keepDims) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 3b54ad2e08d..b96d2dfa1d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -106,8 +106,8 @@ public Op resetStates() { } /** - * Updates the metric variables based on the inputs. At least one input arg required for {@}code - * values}, an optional additional input for the sampleWeights + * Updates the metric variables based on the inputs. At least one input arg required for {@code + * values}, an optional additional input for the {@code sampleWeights} * * @param values the inputs to be passed to update state, this may not be null * @param sampleWeights sample weights to be applied to values, will default to 1 if null. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 467dea19b57..68157632557 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -26,16 +26,16 @@ public class SetsOps { /** - * Computes set difference of elements in last dimension of a and b with - * aMinusB set to true. + * Computes set difference of elements in last dimension of {@code a} and {@code b} with + * {@code aMinusB} set to true. * - *

All but the last dimension of a and b must match + *

All but the last dimension of {@code a} and {@code b} must match * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -44,16 +44,16 @@ public static Operand difference(Ops tf, Operand a, Op } /** - * Computes set difference of elements in last dimension of a and b. + * Computes set difference of elements in last dimension of {@code a} and {@code b}. * - *

All but the last dimension of a and b must match + *

All but the last dimension of {@code a} and {@code b} must match * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param aMinusB whether to subtract b from a, vs vice versa. * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -63,13 +63,13 @@ public static Operand difference( } /** - * Computes set union of elements in last dimension of a and b. + * Computes set union of elements in last dimension of {@code a} and {@code b}. * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -78,13 +78,13 @@ public static Operand union(Ops tf, Operand a, Operand } /** - * Computes set intersection of elements in last dimension of a and b. + * Computes set intersection of elements in last dimension of {@code a} and {@code b}. 
* * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -93,14 +93,14 @@ public static Operand intersection(Ops tf, Operand a, } /** - * Compute set operation of elements in last dimension of a and b. + * Compute set operation of elements in last dimension of {@code a} and {@code b}. * * @param tf the TensorFlow Ops * @param a The first set operation operand * @param b The other et operation operand * @param setOperation The set operation to perform, {@link Operation}. * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the same. Elements along the last dimension contain the results of the set * operation. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 36792b8ea7a..fc7f1abbd89 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -151,17 +151,17 @@ private static Operand hasValidDims( /** * Broadcast `weights` to the same shape as `values`. * - *

This returns a version of weights following the same broadcast rules as + *

This returns a version of {@code weights} following the same broadcast rules as {@code * mul(weights, - * values), but limited to the weights shapes allowed by assertBroadcastable - * When computing a weighted average, use this function to broadcast weights before - * summing them; e.g., reduceSum(w * v) / reduceSum(_broadcast_weights(w, v)). + * values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} + * When computing a weighted average, use this function to broadcast {@code weights} before + * summing them; e.g., {@code reduceSum(w * v) / reduceSum(_broadcast_weights(w, v))}. * * @param tf the TensorFlow ops * @param weights `Tensor` whose shape is able to be broadcast to `values` * @param values Tensor` of any shape * @param the type of Operand - * @return weights broadcast to values shape + * @return {@code weights} broadcast to {@code values} shape */ public static Operand broadcastWeights( Ops tf, Operand weights, Operand values) { From b91cabff91da7cd370423c70b33e396ad2d32c9e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 12 Apr 2021 12:55:24 -0400 Subject: [PATCH 87/97] Use zero operand for initializing falsePositives --- .../main/java/org/tensorflow/framework/metrics/Precision.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index c56c53addf0..8b0de5042bf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -290,7 +290,7 @@ private void init() { if (this.falsePositives == null) { this.falsePositives = tf.withName(falsePositivesName) - .variable(zeros.call(tf.constant(Shape.of(thresholds.length)), type)); + .variable(zero); initializers.add(tf.assign(falsePositives, zero)); } } From 
56b0300d816943e46bddba7a394e7be21a6cca3a Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 12 Apr 2021 12:58:10 -0400 Subject: [PATCH 88/97] Usefix javadoc add logic to accept a 1 item weightsShapeStatic. change is_scalar to isScalar Fix logic in hasValidDims to mactch Pyhton implementation. --- .../metrics/impl/WeightsBroadcastOps.java | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index fc7f1abbd89..aa51da729a8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -18,6 +18,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.NoOp; import org.tensorflow.types.TBool; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; @@ -28,21 +29,27 @@ import static org.tensorflow.framework.utils.CastHelper.cast; +/** + * Weight broadcasting operations. + * + *

In {@link org.tensorflow.framework.losses} and `{@link org.tensorflow.framework.metrics}, we support limited weight broadcasting. This file includes + * operations for those broadcasting rules. + */ public class WeightsBroadcastOps { private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX = "weights can not be broadcast to values."; /** - * Asserts that `weights` can be broadcast to `values` + * Asserts that {@code weights} can be broadcast to {@code values} * * @param tf the TensorFlow Ops - * @param weights `Tensor` of weights. - * @param values `Tensor` of values to which weights are applied. - * @return `Operation` raising `InvalidArgumentError` if `weights` has incorrect shape. `no_op` if - * static checks determine `weights` has correct shape. + * @param weights the weights Operand + * @param values Operand of values to which weights are applied. + * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has incorrect shape. {@link NoOp} if + * static checks determine {@code weights} has correct shape. * @param the type of weights and values - * @throws IllegalArgumentException If static checks determine `weights` has incorrect shape. + * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect shape. */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -75,7 +82,7 @@ public static Op assertBroadcastable( } for (int i = 0; i < valuesRankStatic; i++) { - if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + if (weightsShapeStatic.size(i) != 1 && valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { throw new IllegalArgumentException( String.format( "%s Mismatch at dim %s. values.shape=%s weights.shape=%s.", @@ -90,7 +97,7 @@ public static Op assertBroadcastable( .noOp(); } // Dynamic checks. 
- Operand is_scalar = tf.math.equal(weightsRank, tf.constant(0)); + Operand isScalar = tf.math.equal(weightsRank, tf.constant(0)); List> data = Arrays.asList( tf.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX), @@ -98,13 +105,13 @@ public static Op assertBroadcastable( weightsShape, tf.constant("values.shape="), valuesShape, - tf.constant("is_scalar="), - is_scalar); + tf.constant("isScalar="), + isScalar); Operand isValidShape = tf.select( - is_scalar, - is_scalar, + isScalar, + isScalar, hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); return tf.assertThat(isValidShape, data); @@ -134,7 +141,7 @@ private static Operand hasValidNonscalarShape( } /** - * Checks that each dimension of the two shapes are the same + * Checks that each dimension of the two shapes are the same size, or that the weight dimension size is 1. * * @param tf the TensorFlow Ops * @param weightsShape the shape of the weights @@ -144,12 +151,18 @@ private static Operand hasValidNonscalarShape( private static Operand hasValidDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("hasInvalidDims"); - Operand diff = tf.reduceSum(tf.math.sub(weightsShape, valuesShape), tf.constant(0)); - return tf.math.equal(tf.constant(0), diff); + + Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); + Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); + Operand weightsShape2d = tf.expandDims(weightsShape, tf.constant(-1)); + + Operand invalidDims = SetsOps.difference(tf, weightsShape2d, validDims); + Operand numInvalidDims = tf.size(invalidDims, TInt32.class); + return tf.math.equal(tf.constant(0), numInvalidDims); } /** - * Broadcast `weights` to the same shape as `values`. + * Broadcast {@code weights} to the same shape as {@code values}. * *

This returns a version of {@code weights} following the same broadcast rules as {@code * mul(weights, @@ -158,7 +171,7 @@ private static Operand hasValidDims( * summing them; e.g., {@code reduceSum(w * v) / reduceSum(_broadcast_weights(w, v))}. * * @param tf the TensorFlow ops - * @param weights `Tensor` whose shape is able to be broadcast to `values` + * @param weights Operand whose shape is able to be broadcast to {@code values} * @param values Tensor` of any shape * @param the type of Operand * @return {@code weights} broadcast to {@code values} shape From f1203aa9424a0a1c014838dc97cf180542e5d617 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 12 Apr 2021 13:00:05 -0400 Subject: [PATCH 89/97] fix javadoc add checks for sampleWeights rank matching lables rank add checks on labels and predicitons to make sure they have the same number of elements. --- .../tensorflow/framework/metrics/MeanIoU.java | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java index 03c31b2bab8..22baab3d6cb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -35,12 +35,12 @@ * *

Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, * which first computes the IOU for each semantic class and then computes the average over classes. - * IOU is defined as follows: {@code IOU = true_positive - * / (true_positive + false_positive + false_negative)}. The predictions are accumulated in a - * confusion matrix, weighted by sample_weight and the metric is then calculated from it. + * IOU is defined as follows: {@code IOU = true_positive / (true_positive + false_positive + + * false_negative)}. The predictions are accumulated in a confusion matrix, weighted by + * sample_weight and the metric is then calculated from it. * - *

If {@code sampleWeight} is {@code null}, weights default to 1. Use sample_weight of - * 0 to mask values. + *

If {@code sampleWeight} is {@code null}, weights default to 1. Use sample_weight of 0 to mask + * values. * * @param The data type for the metric result */ @@ -124,12 +124,35 @@ public Assign getInitializer() { * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either * 0, or the same rank as labels, and must be broadcastable to labels. * @return the Operands that updates totalConfusionMatrix variable + * @throws IllegalArgumentException if the weights rank is not 0, and weights rank @{code !=} labels rank, + * and if the predictions size is not equal to the labels size */ @Override public List updateStateList( Operand labels, Operand predictions, Operand sampleWeights) { + if (sampleWeights != null) { + long weightsRank = sampleWeights.shape().numDimensions(); + long labelsRank = labels.shape().numDimensions(); + if (weightsRank != 0 + && weightsRank != Shape.UNKNOWN_SIZE + && labelsRank != Shape.UNKNOWN_SIZE + && weightsRank != labelsRank) { + throw new IllegalArgumentException( + String.format( + "Weights must either have rank 0, or the same rank as labels, weights rank = %d, labels rank = %d", + weightsRank, labelsRank)); + } + } + long labelsSize = labels.shape().size(); + long predictionsSize = predictions.shape().size(); + if (labelsSize != predictionsSize) { + throw new IllegalArgumentException( + String.format( + "labels and predictions must have the same size, labels size = %d, predictions size = %d", + labelsSize, predictionsSize)); + } Operand tLabels = cast(getTF(), labels, type); if (tLabels.shape().numDimensions() > 1) { From dedaede7f67068f741c87da1d7666034ce859a80 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 12 Apr 2021 13:01:14 -0400 Subject: [PATCH 90/97] add axis to tf.gather, and tf.squeeze results. 
--- .../tensorflow/framework/metrics/impl/MetricsHelper.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 54b2646a62b..bd260fc7ebd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -25,6 +25,7 @@ import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.OneHot; import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.Squeeze; import org.tensorflow.op.core.Stack; import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; @@ -438,8 +439,12 @@ tPredictions, cast(tf, tf.constant(0), tPredictions.type())), if (classIndex != null) { // Slice to new shapes (N, Dx) - tLabels = tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(1)); - tPredictions = tf.gather(tPredictions, tf.constant(new int[] {classIndex}), tf.constant(1)); + tLabels = tf.squeeze(tf.gather(tLabels, + tf.constant(new int[] {classIndex}), tf.constant(-1)), + Squeeze.axis(Collections.singletonList(1L))); + tPredictions = tf.squeeze(tf.gather(tPredictions, + tf.constant(new int[] {classIndex}), tf.constant(-1)), + Squeeze.axis(Collections.singletonList(1L))); } org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); From 09740382e8f30cd2c9dd6609f3a265dc8ceb7e99 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 16 Apr 2021 11:00:57 -0400 Subject: [PATCH 91/97] Remove unnecessary toStrings() on Shapes. 
--- .../tensorflow/framework/metrics/impl/MetricsHelper.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index bd260fc7ebd..40336233d21 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -104,8 +104,8 @@ public static Op assertBroadcastable( ASSERT_BROADCAST_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, - valuesShapeStatic.toString(), - weightsShapeStatic.toString())); + valuesShapeStatic, + weightsShapeStatic)); } for (int i = 0; i < valuesRankStatic; i++) { @@ -116,8 +116,8 @@ public static Op assertBroadcastable( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", ASSERT_BROADCAST_ERROR_PREFIX, i, - valuesShapeStatic.toString(), - weightsShapeStatic.toString())); + valuesShapeStatic, + weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") From a3deb5c518c1bd1a4f7789869d7e7e1b8586d8a4 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 18 Apr 2021 18:31:41 -0400 Subject: [PATCH 92/97] remove value from SensititySpecificityBase, fix sub-class CTORs. 
--- .../framework/metrics/PrecisionAtRecall.java | 12 ++++++------ .../framework/metrics/RecallAtPrecision.java | 8 ++++---- .../metrics/SensitivityAtSpecificity.java | 10 +++++----- .../metrics/SpecificityAtSensitivity.java | 10 +++++----- .../metrics/impl/SensitivitySpecificityBase.java | 15 +++------------ 5 files changed, 23 insertions(+), 32 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java index 483b2523d74..5f5f9b47a10 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -104,7 +104,7 @@ public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Cla */ public PrecisionAtRecall( Ops tf, String name, float recall, int numThresholds, long seed, Class type) { - super(tf, name, recall, numThresholds, seed, type); + super(tf, name, numThresholds, seed, type); if (recall < 0f || recall > 1f) throw new IllegalArgumentException("recall must be in the range [0, 1]."); this.recall = recall; @@ -115,14 +115,14 @@ public PrecisionAtRecall( public Operand result() { Ops tf = getTF(); - Operand recall = - tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); - Operand sub = tf.math.sub(recall, cast(tf, tf.constant(value), getType())); + Operand div = + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + Operand sub = tf.math.sub(div, cast(tf, tf.constant(recall), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); - Operand trueSlice = tf.slice(this.truePositives, minIndex, tf.constant(new int[] {1})); - Operand falseSlice = tf.slice(this.falsePositives, minIndex, tf.constant(new int[] {1})); 
+ Operand trueSlice = tf.slice(truePositives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(falsePositives, minIndex, tf.constant(new int[] {1})); return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java index 72eaedb9c4d..a3fc2f77b7f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -109,7 +109,7 @@ public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, */ public RecallAtPrecision( Ops tf, String name, float precision, int numThresholds, long seed, Class type) { - super(tf, name, precision, numThresholds, seed, type); + super(tf, name, numThresholds, seed, type); if (precision < 0f || precision > 1f) throw new IllegalArgumentException("recall must be in the range [0, 1]."); this.precision = precision; @@ -121,11 +121,11 @@ public Operand result() { Ops tf = getTF(); Operand precisions = - tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falsePositives)); + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); Operand recalls = - tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); Operand isFeasible = - tf.math.greaterEqual(precisions, cast(tf, tf.constant(this.value), getType())); + tf.math.greaterEqual(precisions, cast(tf, tf.constant(precision), getType())); Where feasible = tf.where(isFeasible); Operand feasibleExists = tf.math.greater(tf.size(feasible), tf.constant(0)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java index 7cf5f38d9a4..29c0504b823 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -115,7 +115,7 @@ public SensitivityAtSpecificity( */ public SensitivityAtSpecificity( Ops tf, String name, float specificity, int numThresholds, long seed, Class type) { - super(tf, name, specificity, numThresholds, seed, type); + super(tf, name, numThresholds, seed, type); if (specificity < 0f || specificity > 1f) throw new IllegalArgumentException("specificity must be in the range [0, 1]."); this.specificity = specificity; @@ -126,13 +126,13 @@ public SensitivityAtSpecificity( public Operand result() { Ops tf = getTF(); Operand specificities = - tf.math.divNoNan(this.trueNegatives, tf.math.add(this.trueNegatives, this.falsePositives)); - Operand sub = tf.math.sub(specificities, cast(tf, tf.constant(this.getValue()), getType())); + tf.math.divNoNan(trueNegatives, tf.math.add(trueNegatives, falsePositives)); + Operand sub = tf.math.sub(specificities, cast(tf, tf.constant(specificity), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); - Operand trueSlice = tf.slice(this.truePositives, minIndex, tf.constant(new int[] {1})); - Operand falseSlice = tf.slice(this.falseNegatives, minIndex, tf.constant(new int[] {1})); + Operand trueSlice = tf.slice(truePositives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(falseNegatives, minIndex, tf.constant(new int[] {1})); return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java index 981171f2221..2cb7e54eba0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -114,7 +114,7 @@ public SpecificityAtSensitivity( */ public SpecificityAtSensitivity( Ops tf, String name, float sensitivity, int numThresholds, long seed, Class type) { - super(tf, name, sensitivity, numThresholds, seed, type); + super(tf, name, numThresholds, seed, type); if (sensitivity < 0f || sensitivity > 1f) throw new IllegalArgumentException("sensitivity must be in the range [0, 1]."); this.sensitivity = sensitivity; @@ -127,13 +127,13 @@ public Operand result() { Ops tf = getTF(); Operand sensitivities = - tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); - Operand sub = tf.math.sub(sensitivities, cast(tf, tf.constant(this.getValue()), getType())); + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + Operand sub = tf.math.sub(sensitivities, cast(tf, tf.constant(sensitivity), getType())); Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); minIndex = tf.expandDims(minIndex, tf.constant(0)); - Operand trueSlice = tf.slice(this.trueNegatives, minIndex, tf.constant(new int[] {1})); - Operand falseSlice = tf.slice(this.falsePositives, minIndex, tf.constant(new int[] {1})); + Operand trueSlice = tf.slice(trueNegatives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(falsePositives, minIndex, tf.constant(new int[] {1})); return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java index 08b298294ac..60a6c1ea3df 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -32,7 +32,7 @@ public abstract class SensitivitySpecificityBase extends Metr public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES"; public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; protected final int numThresholds; - protected final float value; + protected final float[] thresholds; private final String truePositivesName; private final String falsePositivesName; @@ -54,7 +54,6 @@ public abstract class SensitivitySpecificityBase extends Metr * * @param tf the TensorFlow Ops * @param name the name of the metric instance, if null then {@link Class#getSimpleName()} is used - * @param value A scalar value in range `[0, 1]` * @param numThresholds The number of thresholds to use for matching the given recall. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. @@ -62,7 +61,7 @@ public abstract class SensitivitySpecificityBase extends Metr * @throws IllegalArgumentException if numThresholds <= 0. 
*/ protected SensitivitySpecificityBase( - Ops tf, String name, float value, int numThresholds, long seed, Class type) { + Ops tf, String name, int numThresholds, long seed, Class type) { super(tf, name, seed); if (numThresholds <= 0) throw new IllegalArgumentException("numThresholds must be > 0."); this.type = type; @@ -71,7 +70,6 @@ protected SensitivitySpecificityBase( this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES); this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); - this.value = value; this.numThresholds = numThresholds; if (this.numThresholds == 1) { @@ -230,14 +228,7 @@ public int getNumThresholds() { return numThresholds; } - /** - * Gets the value - * - * @return the value - */ - public float getValue() { - return value; - } + /** * Gets the thresholds From 8cdc77637cac4cf4e91f717483bff6966a46bec5 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 18 Apr 2021 18:32:15 -0400 Subject: [PATCH 93/97] fix result to reshape slice to scalar. --- .../java/org/tensorflow/framework/metrics/Precision.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java index 8b0de5042bf..3812e799b75 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -340,11 +340,12 @@ public List updateStateList( public Operand result() { Ops tf = getTF(); Operand result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); - return thresholds.length == 1 - ? tf.slice( + return thresholds.length == 1 + ? 
tf.reshape(tf.slice( result, tf.expandDims(tf.constant(0), tf.constant(0)), tf.expandDims(tf.constant(1), tf.constant(0))), + tf.constant(Shape.scalar())) : result; } From 0a19b80cc8e7832e0131a9fcc88367fbc8564814 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 18 Apr 2021 18:33:19 -0400 Subject: [PATCH 94/97] Change Collections.EMPTY_LIST to Collections.emptyList(), remove SuppressWarnings("unchecked"); --- .../framework/metrics/impl/WeightsBroadcastOps.java | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index aa51da729a8..6583465da2e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -51,7 +51,6 @@ public class WeightsBroadcastOps { * @param the type of weights and values * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect shape.
*/ - @SuppressWarnings("unchecked") public static Op assertBroadcastable( Ops tf, Operand weights, Operand values) { Operand weightsShape = tf.shape(weights); @@ -67,7 +66,7 @@ public static Op assertBroadcastable( if (weightsRankStatic != -1 && valuesRankStatic != -1) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") - .withControlDependencies(Collections.EMPTY_LIST) + .withControlDependencies(Collections.emptyList()) .noOp(); } if (weightsRankStatic != valuesRankStatic) { @@ -77,8 +76,8 @@ public static Op assertBroadcastable( ASSERT_BROADCASTABLE_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, - valuesShapeStatic.toString(), - weightsShapeStatic.toString())); + valuesShapeStatic, + weightsShapeStatic)); } for (int i = 0; i < valuesRankStatic; i++) { @@ -88,12 +87,12 @@ public static Op assertBroadcastable( "%s Mismatch at dim %s. values.shape=%s weights.shape=%s.", ASSERT_BROADCASTABLE_ERROR_PREFIX, i, - valuesShapeStatic.toString(), - weightsShapeStatic.toString())); + valuesShapeStatic, + weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") - .withControlDependencies(Collections.EMPTY_LIST) + .withControlDependencies(Collections.emptyList()) .noOp(); } // Dynamic checks. From 82dbbde4b25b70de1902c0a3a0c28133aa0782a6 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 18 Apr 2021 18:34:18 -0400 Subject: [PATCH 95/97] Add test testCummulative(), to make sure multiple calls were adding to the CM variables. 
--- .../tensorflow/framework/metrics/AUCTest.java | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java index 857a5c93f7a..f9ebfe76cb3 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -23,6 +23,7 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -62,6 +63,52 @@ public void testValueIsIdempotent() { } } + @Test + public void testCummulative() { + + // expected variable values after each run. + float[][] tp = {{2f, 1f, 0f}, {4f, 2f, 0f}, {6f, 3f, 0f}}; + float[][] fp = {{2f, 0f, 0f}, {4f, 0f, 0f}, {6f, 0f, 0f}}; + float[][] tn = {{0f, 2f, 2f}, {0f, 4f, 4f}, {0f, 6f, 6f}}; + float[][] fn = {{0f, 1f, 2f}, {0f, 2f, 4f}, {0f, 3f, 6f}}; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(predArray); + Operand yTrue = tf.constant(trueArray); + AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + + session.run(tf.init()); + + assertNull(instance.getTruePositives()); + assertNull(instance.getFalsePositives()); + assertNull(instance.getTrueNegatives()); + assertNull(instance.getFalseNegatives()); + + + + + for (int i = 0; i < 3; i++) { + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + session.evaluate(tp[i], instance.getTruePositives()); + session.evaluate(fp[i], instance.getFalsePositives()); + session.evaluate(tn[i], instance.getTrueNegatives()); + 
session.evaluate(fn[i], instance.getFalseNegatives()); + } + + // test reset + session.run(instance.resetStates()); + for (int i = 0; i < 3; i++) { + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + session.evaluate(tp[i], instance.getTruePositives()); + session.evaluate(fp[i], instance.getFalsePositives()); + session.evaluate(tn[i], instance.getTrueNegatives()); + session.evaluate(fn[i], instance.getFalseNegatives()); + } + } + } + @Test public void basicTestSampleWeight() { try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -71,7 +118,7 @@ public void basicTestSampleWeight() { float[] expectedThresholds = new float[] {-1e-7f, 0.5f, 1 + 1e-7f}; assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); - instance.resetStates(); + Operand yPred = tf.constant(new float[] {0, 0, 1, 1}); Operand yTrue = tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}); Operand sampleWeights = tf.constant(new float[] {1, 0, 0, 1}); @@ -113,6 +160,11 @@ public void testUnweighted() { // float expectedResult = (0.75f * 1 + 0.25f * 0); session.evaluate(0.75f, result); + + session.run(update); + result = instance.result(); + + session.print(result); } } From 47af1160bbc951bcacb4514ab17725522e5aa7bc Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 18 Apr 2021 18:36:01 -0400 Subject: [PATCH 96/97] Fix typo in testCumulative method name --- .../src/test/java/org/tensorflow/framework/metrics/AUCTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java index f9ebfe76cb3..4efd25f71c7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -64,7 +64,7 @@ public void testValueIsIdempotent() { } @Test - public void 
testCummulative() { + public void testCumulative() { // expected variable values after each run. float[][] tp = {{2f, 1f, 0f}, {4f, 2f, 0f}, {6f, 3f, 0f}}; From 1f40a81048b12940f1fb3c5c11582052ebfc3af9 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 18 Apr 2021 18:37:24 -0400 Subject: [PATCH 97/97] remove print statement --- .../test/java/org/tensorflow/framework/metrics/AUCTest.java | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java index 4efd25f71c7..bd693da1312 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -160,11 +160,6 @@ public void testUnweighted() { // float expectedResult = (0.75f * 1 + 0.25f * 0); session.evaluate(0.75f, result); - - session.run(update); - result = instance.result(); - - session.print(result); } }