diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 7d6d159f5ef..acbae4dac6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -364,10 +364,10 @@ public final class Ops { public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -390,8 +390,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index d3094b5e9e9..306361959bf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -42,6 +42,7 @@ public Constraint(Ops tf) { * * @param weights the weights * @return the constrained weights + * @param the data type for weights and results. */ public abstract Operand call(Operand weights); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java new file mode 100644 index 00000000000..bc5047d5855 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java @@ -0,0 +1,1075 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.framework.metrics.impl.SymbolicShape; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that computes the approximate AUC (Area under the curve) via a Riemann sum. + * + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives + * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the + * AUC. To discretize the AUC curve, a linearly spaced set of thresholds is used to compute pairs of + * recall and precision values. The area under the ROC-curve is therefore computed using the height + * of the recall values by the false positive rate, while the area under the PR-curve is the + * computed using the height of the precision values by the recall. + * + *

This value is ultimately returned as {@code auc}, an idempotent operation that computes + * the area under a discretized curve of precision versus recall values (computed using the + * aforementioned variables). The {@code numThresholds} variable controls the degree of + * discretization with larger numbers of thresholds more closely approximating the true AUC. The + * quality of the approximation may vary dramatically depending on {@code numThresholds}. The + * {@code thresholds} parameter can be used to manually specify thresholds which split the + * predictions more evenly. + * + *

For best results, {@code predictions} should be distributed approximately uniformly in + * the range [0, 1] and not peaked around 0 or 1. The quality of the AUC approximation may be poor + * if this is not the case. Setting {@code summationMethod} to {@code minoring} or {@code + * majoring} can help quantify the error in the approximation by providing lower or upper + * bound estimate of the AUC. + * + *

Usage:
+ * + *

+ * AUC m = new org.tensorflow.framework.metrics.AUC( tf, 3);
+ * m.updateState( tf.constant(new float[] {0, 0, 1, 1}),
+ *          tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}));
+ *
+ * // threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]
+ * // tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]
+ * // recall = [1, 0.5, 0], fpRate = [1, 0, 0]
+ * // auc = ((((1+0.5)/2)*(1-0))+ (((0.5+0)/2)*(0-0))) = 0.75
+ * Operand<TFloat32> result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.75
+ * 
+ * + *
+ * m.resetStates();
+ * m.updateState( tf.constant(new float[] {0, 0, 1, 1}),
+ *                 tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}),
+ *                 tf.constant(new float[] {1, 0, 0, 1}));
+ * result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 1.0
+ * 
+ * + * @param The data type for the metric result + */ +public class AUC extends Metric { + + /** Default Fuzz factor. */ + public static final float EPSILON = 1e-7f; + + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + public static final int DEFAULT_NUM_THRESHOLDS = 200; + public static final String DEFAULT_NAME = "auc"; + + private final int numThresholds; + private final AUCCurve curve; + private final AUCSummationMethod summationMethod; + private final float[] thresholds; + private final boolean multiLabel; + private final String truePositivesName; + private final String falsePositivesName; + private final String trueNegativesName; + private final String falseNegativesName; + private final Map> initializers = new HashMap<>(); + private final Class type; + + /** The size of the label dimension. */ + private Integer numLabels; + + private Operand labelWeights; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ + private Variable truePositives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ + private Variable falsePositives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ + private Variable trueNegatives; + + /** + * If not {@link #multiLabel}, shape (T) where T is the number of thresholds. + * + *

If {@link #multiLabel}, shape (T, C0) where T is the number of thresholds and C0 is a single + * class dimension within each example. + */ + private Variable falseNegatives; + + private boolean initialized; + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link #DEFAULT_NUM_THRESHOLDS} for the numThresholds, {@link AUCCurve#ROC} for the curve type, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, long seed, Class type) { + this( + tf, + null, + DEFAULT_NUM_THRESHOLDS, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NUM_THRESHOLDS} for the + * numThresholds, {@link AUCCurve#ROC} for the curve type, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for thresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, String name, long seed, Class type) { + this( + tf, + name, + DEFAULT_NUM_THRESHOLDS, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the + * summation method, {@code null} for thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, int numThresholds, long seed, Class type) { + this( + tf, + null, + numThresholds, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the + * summation method, {@code null} for numThresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, float[] thresholds, long seed, Class type) { + this( + tf, + null, + DEFAULT_NUM_THRESHOLDS, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric. using {@link AUCCurve#ROC} for the curve type, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, int numThresholds, long seed, Class type) { + this( + tf, + name, + numThresholds, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link + * AUCCurve#ROC} for the curve type, {@link AUCSummationMethod#INTERPOLATION} for the summation + * method, {@link #DEFAULT_NUM_THRESHOLDS} num thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, float[] thresholds, long seed, Class type) { + this( + tf, + name, + DEFAULT_NUM_THRESHOLDS, + AUCCurve.ROC, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link AUCSummationMethod#INTERPOLATION} for + * the summation method, {@code null} for thresholds, {@code false} for multiLabel, and + * {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, int numThresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + name, + numThresholds, + curve, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, {@link #DEFAULT_NUM_THRESHOLDS} num + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. 
+ * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, String name, float[] thresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + name, + DEFAULT_NUM_THRESHOLDS, + curve, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@link AUCSummationMethod#INTERPOLATION} for the summation method, {@code null} for + * thresholds, {@code false} for multiLabel, and {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ */ + public AUC(Ops tf, int numThresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + null, + numThresholds, + curve, + AUCSummationMethod.INTERPOLATION, + null, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@code null} for numThresholds, {@link + * AUCSummationMethod#INTERPOLATION} for the summation method, {@code false} for multiLabel, + * and {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC(Ops tf, float[] thresholds, AUCCurve curve, long seed, Class type) { + this( + tf, + null, + DEFAULT_NUM_THRESHOLDS, + curve, + AUCSummationMethod.INTERPOLATION, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric. using {@link #DEFAULT_NAME} for the metric name,, + * {@code null} for thresholds, {@code false} for multiLabel, and {@code null} for + * labelWeights. + * + * @param tf The TensorFlow Ops + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + int numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, null, numThresholds, curve, summationMethod, null, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric using {@link #DEFAULT_NAME} for the metric name, + * {@code null} for numThresholds, {@code false} for multiLabel, and {@code null} + * for labelWeights. + * + * @param tf The TensorFlow Ops + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + float[] thresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this( + tf, + null, + DEFAULT_NUM_THRESHOLDS, + curve, + summationMethod, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric. using {@code null} for thresholds, {@code + * false} for multiLabel, and {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} + * @param numThresholds the number of thresholds to use when discretizing the roc curve. Values + * must be > 1. 
+ * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used, + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + String name, + int numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this(tf, name, numThresholds, curve, summationMethod, null, false, null, seed, type); + } + + /** + * Creates an AUC (Area under the curve) metric. using {@code null} for the numThresholds, + * {@code false} for multiLabel, and {@code null} for labelWeights. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if {@code null} defaults to {@link #DEFAULT_NAME} + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. + */ + public AUC( + Ops tf, + String name, + float[] thresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + long seed, + Class type) { + this( + tf, + name, + DEFAULT_NUM_THRESHOLDS, + curve, + summationMethod, + thresholds, + false, + null, + seed, + type); + } + + /** + * Creates an AUC (Area under the curve) metric. 
+ * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}. + * @param numThresholds the number of thresholds to use when discretizing the roc curve. This + * includes the bracketing 0 and 1 thresholds, so the value must be ≥ 2. + * @param curve specifies the type of the curve to be computed, {@link AUCCurve#ROC} or {@link + * AUCCurve#PR} for the Precision-Recall-curve. + * @param summationMethod Specifies the Riemann summation method used + * @param thresholds Optional values to use as the thresholds for discretizing the curve. If set, + * the numThresholds parameter is ignored. Values should be in [0, 1]. This method + * automatically brackets the provided {@code thresholds} with a (-{@link #EPSILON}) + * below and a (1 + {@link #EPSILON}) above. + * @param multiLabel boolean indicating whether multilabel data should be treated as such, wherein + * AUC is computed separately for each label and then averaged across labels, or (when false) + * if the data should be flattened into a single label before AUC computation. In the latter + * case, when multilabel data is passed to AUC, each label-prediction pair is treated as an + * individual data point. Should be set to {@code false} for multi-class data. + * @param labelWeights non-negative weights used to compute AUCs for multilabel data. When {@code + * multiLabel} is true, the weights are applied to the individual label AUCs when they + * are averaged to produce the multi-label AUC. When it's false, they are used to weight the + * individual label predictions in computing the confusion matrix on the flattened data. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the confusion matrix variables. 
+ * @throws IllegalArgumentException if numThresholds is less than 2 and thresholds is null, or if + * a threshold value is less than 0 or greater than 1. + */ + public AUC( + Ops tf, + String name, + int numThresholds, + AUCCurve curve, + AUCSummationMethod summationMethod, + float[] thresholds, + boolean multiLabel, + Operand labelWeights, + long seed, + Class type) { + super(tf, name == null ? DEFAULT_NAME : name, seed); + truePositivesName = getVariableName(TRUE_POSITIVES); + falsePositivesName = getVariableName(FALSE_POSITIVES); + trueNegativesName = getVariableName(TRUE_NEGATIVES); + falseNegativesName = getVariableName(FALSE_NEGATIVES); + this.curve = curve; + this.summationMethod = summationMethod; + this.type = type; + + this.multiLabel = multiLabel; + + if (thresholds != null) { // ignore numThresholds + for (float t : thresholds) { + if (t < 0.0f || t > 1.0f) { + throw new IllegalArgumentException( + String.format( + "Threshold values must be in range [0, 1], inclusive. Invalid values: %s", + Arrays.toString(thresholds))); + } + } + this.numThresholds = thresholds.length + 2; + Arrays.sort(thresholds); + } else { + + if (numThresholds <= 1) { + throw new IllegalArgumentException("numThresholds must be > 1."); + } + this.numThresholds = numThresholds; + thresholds = new float[this.numThresholds - 2]; + // linearly interpolate (numThresholds - 2) thresholds between endpoints + for (int i = 0; i < thresholds.length; i++) { + thresholds[i] = (i + 1) * 1.0f / (this.numThresholds - 1); + } + } + + // Add an endpoint "threshold" below zero and above one for either + // threshold method to account for floating point imprecisions. + this.thresholds = new float[this.numThresholds]; + this.thresholds[0] = -EPSILON; + System.arraycopy(thresholds, 0, this.thresholds, 1, thresholds.length); + this.thresholds[this.numThresholds - 1] = 1 + EPSILON; + + // Handle multilabel arguments. + + if (labelWeights != null) { + // assert that labelWeights are non-negative. 
+ + this.labelWeights = labelWeights; + Op checks = + tf.withSubScope("AUC") + .assertThat( + tf.math.greaterEqual(labelWeights, cast(tf, tf.constant(0), labelWeights.type())), + Collections.singletonList( + tf.constant("All values of labelWeights must be non-negative."))); + + Ops ltf = + tf.withSubScope("updateState").withControlDependencies(Collections.singletonList(checks)); + + this.labelWeights = ltf.identity(this.labelWeights); + } + + if (multiLabel) { + numLabels = null; + } + } + + /** + * Initialize truePositives, falsePositives, trueNegatives, and falseNegatives variables, given + * the shape of the data. + * + * @param shape the prediction shape if called from updateState, otherwise null + */ + @SuppressWarnings("unchecked") + private Map> build(Shape shape) { + Shape variableShape; + if (initialized) { + return Collections.EMPTY_MAP; + } + Ops tf = getTF(); + + if (isMultiLabel()) { + if (shape == null) { + throw new IllegalArgumentException("For multiLabel, a shape must be provided"); + } + if (shape.numDimensions() != 2) + throw new IllegalArgumentException( + String.format( + "labels must have rank=2 when multiLabel is true. 
Found rank %d.", + shape.numDimensions())); + numLabels = (int) shape.size(1); + variableShape = Shape.of(numThresholds, numLabels); + } else { + variableShape = Shape.of(numThresholds); + } + + // Create metric variables + Zeros zeros = new Zeros<>(tf); + Operand zero = zeros.call(tf.constant(variableShape), type); + if (truePositives == null) { + truePositives = tf.withName(getTruePositivesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.TRUE_POSITIVES, tf.assign(truePositives, zero)); + } + + if (falsePositives == null) { + falsePositives = tf.withName(getFalsePositivesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.FALSE_POSITIVES, tf.assign(falsePositives, zero)); + } + + if (trueNegatives == null) { + trueNegatives = tf.withName(getTrueNegativesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.TRUE_NEGATIVES, tf.assign(trueNegatives, zero)); + } + + if (falseNegatives == null) { + falseNegatives = tf.withName(getFalseNegativesName()).variable(zero); + initializers.put(ConfusionMatrixEnum.FALSE_NEGATIVES, tf.assign(falseNegatives, zero)); + } + + initialized = true; + return initializers; + } + + /** + * Creates a List of Operations to update the metric state based on labels and predictions. + * + * @param labels shape (N, Cx, L1?) where N is the number of examples, Cx is zero or more class + * dimensions, and L1 is a potential extra dimension of size 1 that would be squeezed. Will be + * cast to {@code }. If {@link #multiLabel} or if {@link #labelWeights} {@code != null + * }, then Cx must be a single dimension. + * @param predictions the predictions shape (N, Cx, P1?). Will be cast to {@code T}. + * @param sampleWeights sample weights to be applied to values, may be null. Will be cast to + * {@code }. 
+ * @return a List of Operations to update the metric state + */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Ops tf = getTF(); + Operand tLabels = cast(tf, labels, type); + Operand tPredictions = cast(tf, predictions, type); + Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null; + List updateOperations = new ArrayList<>(); + Map> varInitializers = Collections.EMPTY_MAP; + if (!initialized) { + varInitializers = build(tPredictions.shape()); + } + if (isMultiLabel() || getLabelWeights() != null) { + // labels should have shape (number of examples, number of labels). + List> symbols = new ArrayList<>(); + symbols.add(new SymbolicShape<>(tLabels, "N", "L")); + if (isMultiLabel()) { + // TP, TN, FP, and FN should all have shape + // (number of thresholds, number of labels). + symbols.add(new SymbolicShape<>(truePositives, "T", "L")); + symbols.add(new SymbolicShape<>(falsePositives, "T", "L")); + symbols.add(new SymbolicShape<>(trueNegatives, "T", "L")); + symbols.add(new SymbolicShape<>(falseNegatives, "T", "L")); + } + if (getLabelWeights() != null) { + symbols.add(new SymbolicShape<>(getLabelWeights(), "L")); + } + updateOperations.addAll( + MetricsHelper.assertShapes(tf, symbols, "Number of labels is not consistent.")); + } + + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, trueNegatives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, falseNegatives); + + // Only forward labelWeights to update_confusion_matrix_variables when + // multiLabel is false. 
Otherwise the averaging of individual label AUCs is + // handled in AUC.result + updateOperations.addAll( + MetricsHelper.updateConfusionMatrixVariables( + tf, + confusionMatrix, + varInitializers, + tLabels, + tPredictions, + tf.constant(thresholds), + null, + null, + tSampleWeights, + isMultiLabel(), + isMultiLabel() ? null : getLabelWeights())); + return updateOperations; + } + + /** + * Gets the input with all positive numbers. Negative numbers are set to 0. + * + * @param input the input + * @return the input with all positive numbers. + */ + private Operand positive(Operand input) { + return getTF().math.maximum(input, cast(getTF(), getTF().constant(0), input.type())); + } + + /** + * Gets the truth value of whether {@code input > 0}, element-wise. + * + * @param input the input + * @return the truth value of whether {@code input > 0}, element-wise. + */ + private Operand isPositive(Operand input) { + return getTF().math.greater(input, cast(getTF(), getTF().constant(0), input.type())); + } + + /** + * Extracts a slice from the input. + * + * @param input the input + * @param begin the beginning location of the slice + * @param size the size of the slice + * @return the slice + */ + private Operand slice(Operand input, int begin, int size) { + return getTF() + .slice(input, getTF().constant(new int[] {begin}), getTF().constant(new int[] {size})); + } + + /** + * Interpolation formula inspired by section 4 of Davis & Goadrich 2006. + * + *

Note here we derive & use a closed formula not present in the paper as follows: + * + *

+   *     Precision = TP / (TP + FP) = TP / P
+   * 
+ * + *

Modeling all of TP (true positive), FP (false positive) and their sum P = TP + FP (predicted + * positive) as varying linearly within each interval [A, B] between successive thresholds, we get + * + *

+   *     Precision slope = dTP / dP
+   *                     = (TP_B - TP_A) / (P_B - P_A)
+   *                     = (TP - TP_A) / (P - P_A)
+   *     Precision = (TP_A + slope * (P - P_A)) / P
+   * 
+ * + *

The area within the interval is (slope / total_pos_weight) times + * + *

+   *       int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}
+   *       int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}
+   * 
+ * + * where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in + * + *
+   *       int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)
+   * 
+ * + * Bringing back the factor (slope / total_pos_weight) we'd put aside, we get + * + *
+   *       slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight
+   * 
+ * + * where dTP == TP_B - TP_A. Note that when P_A == 0 the above calculation simplifies into + * + *
+   *       int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)
+   * 
+ * + * which is really equivalent to imputing constant precision throughout the first bucket having >0 + * true positives. + * + * @return an approximation of the area under the P-R curve. + * @see The Relationship Between + * Precision-Recall and ROC Curves - Davis & Goadrich 2006 + */ + private Operand interpolatePRAuc() { + // truePositives[:self.numThresholds - 1] + Ops tf = getTF(); + Operand tp0 = slice(truePositives, 0, getNumThresholds() - 1); + // truePositives[1:] + Operand tp1 = slice(truePositives, 1, -1); + + Operand dTP = tf.math.sub(tp0, tp1); + + Operand p = tf.math.add(truePositives, falsePositives); + Operand p0 = slice(p, 0, getNumThresholds() - 1); + Operand p1 = slice(p, 1, -1); + + Operand dP = tf.math.sub(p0, p1); + + Operand precisionSlope = tf.math.divNoNan(dTP, positive(dP)); + + Operand intercept = tf.math.sub(tp1, tf.math.mul(precisionSlope, p1)); + + Operand safePRatio = + tf.select( + tf.math.logicalAnd(isPositive(p0), isPositive(p1)), + tf.math.divNoNan(p0, positive(p1)), + tf.onesLike(p1)); + + Operand fn1 = slice(falseNegatives, 1, -1); + + Operand aucTotalPos = + tf.math.mul( + precisionSlope, tf.math.add(dTP, tf.math.mul(intercept, tf.math.log(safePRatio)))); + + Operand prAucIncrement = tf.math.divNoNan(aucTotalPos, positive(tf.math.add(tp1, fn1))); + + if (isMultiLabel()) { + Operand byLabelAuc = tf.reduceSum(prAucIncrement, tf.constant(0)); + if (getLabelWeights() == null) { + // Evenly weighted average of the label AUCs. + return MetricsHelper.mean(tf, byLabelAuc); + } else { + // Weighted average of the label AUCs. 
+ return tf.math.divNoNan( + tf.reduceSum(tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, byLabelAuc)), + tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()))); + } + } else { + return tf.reduceSum(prAucIncrement, allAxes(tf, prAucIncrement)); + } + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + + if (getCurve() == AUCCurve.PR && getSummationMethod() == AUCSummationMethod.INTERPOLATION) { + // This use case is different and is handled separately. + return interpolatePRAuc(); + } + Ops tf = getTF(); + Operand x; + Operand y; + Operand recall = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + + switch (getCurve()) { + case ROC: + x = tf.math.divNoNan(falsePositives, tf.math.add(falsePositives, trueNegatives)); + y = recall; + break; + case PR: + y = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + x = recall; + break; + default: + throw new IllegalArgumentException("Unexpected AUCCurve value: " + getCurve()); + } + + // Find the rectangle heights based on `summationMethod`. 
+ // y[:self.numThresholds - 1] + Operand ySlice1 = slice(y, 0, getNumThresholds() - 1); + // y[1:] + Operand ySlice2 = slice(y, 1, -1); + + Operand heights; + switch (getSummationMethod()) { + case INTERPOLATION: + heights = tf.math.div(tf.math.add(ySlice1, ySlice2), cast(tf, tf.constant(2), y.type())); + break; + case MINORING: + heights = tf.math.minimum(ySlice1, ySlice2); + break; + case MAJORING: + heights = tf.math.maximum(ySlice1, ySlice2); + break; + default: + throw new IllegalArgumentException( + "Unexpected AUCSummationMethod value: " + getSummationMethod()); + } + + if (isMultiLabel()) { + Operand riemannTerms = + tf.math.mul(tf.math.sub(slice(x, 0, getNumThresholds() - 1), slice(x, 1, -1)), heights); + Operand byLabelAuc = tf.reduceSum(riemannTerms, tf.constant(0)); + + if (getLabelWeights() == null) { + return MetricsHelper.mean(tf, byLabelAuc); + } else { + // Weighted average of the label AUCs. + return tf.math.divNoNan( + tf.reduceSum( + tf.math.mul(byLabelAuc, getLabelWeights()), allAxes(tf, getLabelWeights())), + tf.reduceSum(getLabelWeights(), allAxes(tf, getLabelWeights()))); + } + + } else { + Operand slice1 = slice(x, 0, getNumThresholds() - 1); + Operand slice2 = slice(x, 1, -1); + Operand sub = tf.math.sub(slice1, slice2); + Operand operand = tf.math.mul(sub, heights); + return tf.reduceSum(operand, allAxes(tf, operand)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + List updateOperations = new ArrayList<>(initializers.values()); + return getTF().withSubScope("resetStates").withControlDependencies(updateOperations).noOp(); + } + + /** @return the numThresholds */ + public int getNumThresholds() { + return numThresholds; + } + + /** @return the curve */ + public AUCCurve getCurve() { + return curve; + } + + /** @return the summationMethod */ + public AUCSummationMethod getSummationMethod() { + return summationMethod; + } + + /** @return the thresholds */ + public float[] getThresholds() { + return thresholds; + } 
+ + /** @return the multiLabel */ + public boolean isMultiLabel() { + return multiLabel; + } + + /** @return the numLabels */ + public Integer getNumLabels() { + return numLabels; + } + + /** @param numLabels the numLabels to set */ + public void setNumLabels(Integer numLabels) { + this.numLabels = numLabels; + } + + /** @return the labelWeights */ + public Operand getLabelWeights() { + return labelWeights; + } + + /** @return the truePositives */ + public Variable getTruePositives() { + return truePositives; + } + + /** @return the falsePositives */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** @return the trueNegatives */ + public Variable getTrueNegatives() { + return trueNegatives; + } + + /** @return the falseNegatives */ + public Variable getFalseNegatives() { + return falseNegatives; + } + + /** @return the truePositivesName */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** @return the falsePositivesName */ + public String getFalsePositivesName() { + return falsePositivesName; + } + + /** @return the trueNegativesName */ + public String getTrueNegativesName() { + return trueNegativesName; + } + + /** @return the falseNegativesName */ + public String getFalseNegativesName() { + return falseNegativesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java new file mode 100644 index 00000000000..b5426a0dd8f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCCurve.java @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +/** + * Specifies the type of the curve to be computed, {@link #ROC} for a Receiver Operator + * Characteristic curve [default] or {@link #PR} for a Precision-Recall-curve. + */ +public enum AUCCurve { + /** Receiver Operator Characteristic curve */ + ROC, + /** Precision-Recall-curve */ + PR; + + /** + * Gets the AUCCurve enum value by name, regardless of case + * + * @param name the name of the AUCCurve enum value. + * @return the AUCCurve enum value. + */ + public AUCCurve get(String name) { + return AUCCurve.valueOf(name.toUpperCase()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java new file mode 100644 index 00000000000..735d97ecf09 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUCSummationMethod.java @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +/** + * Specifies the Riemann summation method used. + * + * @see Davis & Goadrich. 2006 + * @see Riemann summation method + */ +public enum AUCSummationMethod { + /** + * Apply mid-point summation scheme for {@link AUCCurve#ROC}. For {@link AUCCurve#PR}, + * interpolates (true/false) positives but not the ratio that is precision + */ + INTERPOLATION, + /** Apply right summation for increasing intervals and left summation for decreasing intervals */ + MAJORING, + /** Apply left summation for increasing intervals and right summation for decreasing intervals */ + MINORING; + + /** + * Gets the AUCSummationMethod enum value by name, regardless of case + * + * @param name the name of the AUCSummationMethod enum value. + * @return the AUCSummationMethod enum value. + */ + public AUCSummationMethod get(String name) { + return AUCSummationMethod.valueOf(name.toUpperCase()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java new file mode 100644 index 00000000000..516d6c91ba6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Accuracy.java @@ -0,0 +1,98 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions equals labels. + * + *

This metric creates two local variables, total and count that are used to compute the + * frequency with which {@code predictions} matches {@code labels}. This frequency is + * ultimately returned as binary accuracy: an idempotent operation that simply divides total by + * count. + * + *

If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask + * values. + * + * @param The data type for the metric result + */ +public class Accuracy extends MeanMetricWrapper implements LossMetric { + + /** + * Creates an Accuracy Metric using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Accuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates an Accuracy Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Accuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** + * Calculates how often predictions equals labels. {@code labels} and {@code predictions} must + * have compatible shapes, see {@link Shape @isCompatibleWith}. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @throws IllegalArgumentException if predictions and labels shapes are not compatible. 
+ * @return the loss + */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + LossTuple tuple = + MetricsHelper.raggedAssertCompatibleAndGetFlatValues(getTF(), tLabels, tPredictions); + tLabels = tuple.getLabels(); + tPredictions = tuple.getTarget(); + + if (!predictions.shape().isCompatibleWith(labels.shape())) { + throw new IllegalArgumentException( + String.format( + "Shapes %s and %s are incompatible", + predictions.shape().toString(), labels.shape().toString())); + } + + // cast TBool to result type + return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java new file mode 100644 index 00000000000..0e41699e165 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryAccuracy.java @@ -0,0 +1,106 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions matches binary labels. + * + *

This metric creates two local variables, total and count that are used to compute the + * frequency with which {@code predictions} matches {@code labels}. This frequency is + * ultimately returned as binary accuracy: an idempotent operation that simply divides total by + * count. + * + *

If sampleWeights is {@code null}, weights default to 1. Use sampleWeights of 0 to mask + * values. + * + * @param The data type for the metric result + */ +public class BinaryAccuracy extends MeanMetricWrapper + implements LossMetric { + /** the default threshold value for deciding whether prediction values are 1 or 0 */ + public static final float DEFAULT_THRESHOLD = 0.5f; + + /** the threshold value for deciding whether prediction values are 1 or 0 */ + private final float threshold; + + /** + * Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name and + * {@link #DEFAULT_THRESHOLD} for the threshold value. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold for deciding whether prediction values are 1 or 0 + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, float threshold, long seed, Class type) { + this(tf, null, threshold, seed, type); + } + + /** + * Creates a BinaryAccuracy Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold for deciding whether prediction values are 1 or 0 + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class type) { + super(tf, name, seed, type); + this.threshold = threshold; + setLoss(this); + } + + /** + * Calculates how often predictions match binary labels. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Binary accuracy values. shape = {@code [batch_size, d0, .. dN-1]} + */ + @Override + public Operand call( + Operand labels, Operand predictions) { + + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Operand thresholdCast = cast(getTF(), getTF().constant(threshold), getResultType()); + tPredictions = + cast(getTF(), getTF().math.greater(tPredictions, thresholdCast), getResultType()); + Operand tLabels = cast(getTF(), labels, getResultType()); + return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java index 48ee244eafb..57a6f75375d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -60,7 +60,14 @@ public BinaryCrossentropy( this.labelSmoothing = labelSmoothing; } - /** {@inheritDoc} */ + /** + * Computes the binary crossentropy loss between labels and predictions. + * + * @param labels the truth values or labels, has the same shape as predictions and shape = {@code + * [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. 
+ * @return Binary crossentropy loss value. shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java new file mode 100644 index 00000000000..dece2d1cd50 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalAccuracy.java @@ -0,0 +1,95 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that calculates how often predictions matches one-hot labels. + * + *

You can provide {@code logits} of classes as {@code predictions}, since argmax of + * {@code logits} and probabilities are same. + * + *

This metric creates two local variables, {@code total} and {@code count} that are + * used to compute the frequency with which {@code predictions} matches {@code labels}. + * This frequency is ultimately returned as categorical accuracy: an idempotent operation that + * simply divides total by count. + * + *

{@code predictions} and {@code labels} should be passed in as vectors of + * probabilities, rather than as labels. If necessary, use {@link + * org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand + * {@code labels} as a vector. + * + *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. + * + * @param The data type for the metric result + */ +public class CategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a CategoricalAccuracy metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public CategoricalAccuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a CategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public CategoricalAccuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + super.setLoss(this); + } + + /** + * Computes the categorical crossentropy loss. + * + *

{@code predictions} and {@code labels} should be passed in as vectors of probabilities, + * rather than as labels. If necessary, use {@link Ops#oneHot} to expand {@code labels} as a + * vector. + * + * @param labels One-hot ground truth values. + * @param predictions The prediction values. + * @return Categorical accuracy values. + */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand trueMax = getTF().math.argMax(labels, getTF().constant(-1)); + + Operand predMax = getTF().math.argMax(predictions, getTF().constant(-1)); + return cast(getTF(), getTF().math.equal(trueMax, predMax), getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java index b22e5415f79..58aa51f664c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -28,9 +28,9 @@ * labels. * *

This is the crossentropy metric class to be used when there are multiple label classes (2 or - * more). The labels should be given as a one_hot representation. eg., When labels values are - * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] - * . + * more). The labels should be given as a one_hot representation. eg., When labels values are {@code + * [2, 0, 1]}, the labels Operand contains = {@code [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + * }. * * @param The data type for the metric result */ @@ -52,9 +52,9 @@ public class CategoricalCrossentropy extends MeanMetricWrappe * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to * a probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 - * means that we will use a value of 0.1 for label 0 and 0.9 - * for label 1 + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} + * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 + * } for label {@code 1} * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result @@ -73,13 +73,13 @@ public CategoricalCrossentropy( * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a * probability distribution. * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, - * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 - * means that we will use a value of 0.1 for label 0 and 0.9 - * for label 1 - * @param axis Int specifying the channels axis. 
axis={@link Losses#CHANNELS_LAST} - * corresponds to data format channels_last, and - * axis={@link Losses#CHANNELS_FIRST} corresponds to data format - * channels_first. + * meaning the confidence on label values are relaxed. e.g. {@code labelSmoothing=0.2} + * means that we will use a value of {@code 0.1} for label {@code 0} and {@code 0.9 + * } for label {@code 1} + * @param axis Int specifying the channels axis. {@code axis={@link Losses#CHANNELS_LAST}} + * corresponds to data format {@code channels_last}, and {@code + * axis={@link Losses#CHANNELS_FIRST}} corresponds to data format {@code + * channels_first}. * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the type for the variables and result @@ -99,7 +99,13 @@ public CategoricalCrossentropy( this.axis = axis; } - /** {@inheritDoc} */ + /** + * Computes the crossentropy loss between the labels and predictions. + * + * @param labels the truth values or labels, of one-hot true targets, same shape as predictions + * @param predictions the predictions + * @return Categorical crossentropy loss value. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java index 4266cc487c0..1f6d0fd002c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -45,7 +45,13 @@ public CategoricalHinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the categorical hinge metric between {@code labels} and @{code predictions}. 
+ * + * @param labels the truth values or labels, labels values are expected to be 0 or 1. + * @param predictions the predictions + * @return Categorical hinge loss values. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java index 840f255c5ab..230286a738f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -26,7 +26,17 @@ /** * A metric that computes the cosine similarity metric between labels and predictions. * + *

Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0 + * indicates orthogonality and values closer to -1 indicate greater similarity. The values closer to + * 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where you + * try to maximize the proximity between predictions and targets. If either labels and predictions + * is a zero vector, cosine similarity will be 0 regardless of the proximity between predictions and + * targets. + * + *

{@code loss = -sum(l2_norm(y_true) * l2_norm(y_pred))}
+ * * @param The data type for the metric result. + * @see Cosine Similarity */ public class CosineSimilarity extends MeanMetricWrapper implements LossMetric { @@ -76,7 +86,13 @@ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class typ setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the cosine similarity loss between labels and predictions. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @return the cosine similarity loss + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java new file mode 100644 index 00000000000..3db7fffc2e9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalseNegatives.java @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of false negatives. + * + *

If {@code sampleWeights} is given, calculates the sum of the weights of false negatives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of + * the number of false negatives. + * + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. + * + * @param The data type for the metric result + */ +public class FalseNegatives extends ConfusionMatrixConditionCount { + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalseNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). 
One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a FalseNegatives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalseNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalseNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalseNegatives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.FALSE_NEGATIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java new file mode 100644 index 00000000000..551529b6179 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/FalsePositives.java @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of false positives. + * + *

If {@code sampleWeights} is given, calculates the sum of the weights of false positives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of + * the number of false positives. + * + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. + * + * @param The data type for the metric result + */ +public class FalsePositives extends ConfusionMatrixConditionCount { + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalsePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). 
One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a FalsePositives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a FalsePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a FalsePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public FalsePositives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.FALSE_POSITIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java index 46ccd2859ff..a2d110867b8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -44,7 +44,13 @@ public Hinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the hinge loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return the hinge loss between labels and predictions. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java index 9ffcd6189f1..155a891ccc2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -45,7 +45,13 @@ public KLDivergence(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes Kullback-Leibler divergence metric between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return the loss with shape {@code [batch_size, d0, .. dN-1]} + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java index 59e24f57110..786847d4b32 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -45,7 +45,13 @@ public LogCoshError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Calculates the Logarithm of the hyperbolic cosine of the prediction error. + * + * @param labels Ground truth values, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Logcosh error values, shape = {@code [batch_size, d0, .. dN-1]}. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java index 1cc6d0b6f99..b38d0a809e1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -45,7 +45,13 @@ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean absolute error loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean absolute error values, shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java index 8c6720b58f6..22bcd0ab0eb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -45,7 +45,13 @@ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean absolute percentage error loss between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean absolute percentage error values, shape = {@code [batch_size, d0, .. 
dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java new file mode 100644 index 00000000000..22baab3d6cb --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java @@ -0,0 +1,203 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the mean Intersection-Over-Union metric. + * + *

Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, + * which first computes the IOU for each semantic class and then computes the average over classes. + * IOU is defined as follows: {@code IOU = true_positive / (true_positive + false_positive + + * false_negative)}. The predictions are accumulated in a confusion matrix, weighted by + * sample_weight and the metric is then calculated from it. + * + *

If {@code sampleWeight} is {@code null}, weights default to 1. Use sample_weight of 0 to mask + * values. + * + * @param The data type for the metric result + */ +public class MeanIoU extends Metric { + + public static final String TOTAL_CONFUSION_MATRIX = "TOTAL_CONFUSION_MATRIX"; + private final String totalCMName; + private final Class type; + /** + * The possible number of labels the prediction task can have. This value must be provided, since + * a confusion matrix of dimension = [numClasses, numClasses] will be allocated. + */ + private final long numClasses; + + private Variable totalConfusionMatrix; + private Assign initializer; + + /** + * Creates a metric MeanIoU, using name as {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param numClasses The possible number of labels the prediction task can have + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + protected MeanIoU(Ops tf, long numClasses, long seed, Class type) { + this(tf, null, numClasses, seed, type); + } + + /** + * Creates a MeanIoU metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param numClasses The possible number of labels the prediction task can have + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + protected MeanIoU(Ops tf, String name, long numClasses, long seed, Class type) { + super(tf, name, seed); + this.type = type; + this.totalCMName = this.getVariableName(TOTAL_CONFUSION_MATRIX); + this.numClasses = numClasses; + init(); + } + + private void init() { + Shape variableShape = Shape.of(numClasses, numClasses); + + if (totalConfusionMatrix == null) { + Zeros zeros = new Zeros<>(getTF()); + totalConfusionMatrix = + getTF().withName(totalCMName).variable(zeros.call(getTF().constant(variableShape), type)); + initializer = + getTF().assign(totalConfusionMatrix, zeros.call(getTF().constant(variableShape), type)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializer; + } + + /** + * Gets the initializer for the totalConfusionMatrix variable + * + * @return the initializer for the totalConfusionMatrix variable + */ + public Assign getInitializer() { + return initializer; + } + + /** + * Accumulates the confusion matrix statistics. + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either + * 0, or the same rank as labels, and must be broadcastable to labels. 
+ * @return the Operands that updates totalConfusionMatrix variable + * @throws IllegalArgumentException if the weights rank is not 0, and weights rank @{code !=} labels rank, + * and if the predictions size is not equal to the labels size + */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + if (sampleWeights != null) { + long weightsRank = sampleWeights.shape().numDimensions(); + long labelsRank = labels.shape().numDimensions(); + if (weightsRank != 0 + && weightsRank != Shape.UNKNOWN_SIZE + && labelsRank != Shape.UNKNOWN_SIZE + && weightsRank != labelsRank) { + throw new IllegalArgumentException( + String.format( + "Weights must either have rank 0, or the same rank as labels, weights rank = %d, labels rank = %d", + weightsRank, labelsRank)); + } + } + long labelsSize = labels.shape().size(); + long predictionsSize = predictions.shape().size(); + if (labelsSize != predictionsSize) { + throw new IllegalArgumentException( + String.format( + "labels and predictions must have the same size, labels size = %d, predictions size = %d", + labelsSize, predictionsSize)); + } + + Operand tLabels = cast(getTF(), labels, type); + if (tLabels.shape().numDimensions() > 1) { + tLabels = getTF().shape.flatten(tLabels); + } + Operand tPredictions = cast(getTF(), predictions, type); + if (tPredictions.shape().numDimensions() > 1) { + tPredictions = getTF().shape.flatten(tPredictions); + } + Operand tSampleWeights = sampleWeights != null ? cast(getTF(), sampleWeights, type) : null; + if (tSampleWeights != null && tSampleWeights.shape().numDimensions() > 1) { + tSampleWeights = getTF().shape.flatten(tSampleWeights); + } + + // Accumulate the prediction to current confusion matrix. 
+ Operand currentCM = + MetricsHelper.confusionMatrix( + getTF(), tLabels, tPredictions, getTF().constant(numClasses), tSampleWeights, type); + return Collections.singletonList(getTF().assignAdd(totalConfusionMatrix, currentCM)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand sumOverRow = tf.reduceSum(totalConfusionMatrix, tf.constant(0)); + Operand sumOverCol = tf.reduceSum(totalConfusionMatrix, tf.constant(1)); + Operand truePositives = + tf.linalg.matrixDiagPart( + totalConfusionMatrix, + tf.constant(0), + cast(tf, tf.constant(0), totalConfusionMatrix.type())); + // for each class, the total predictions + total labels - true positives + // Observe that total predictions = tp + fp + // total labels = tp + fn + // So this is 2*tp + fp + fn - tp = tp + fp + fn + Operand denominator = tf.math.add(sumOverRow, tf.math.sub(sumOverCol, truePositives)); + Operand numValidEntries = + tf.reduceSum( + tf.dtypes.cast( + tf.math.notEqual(denominator, cast(tf, tf.constant(0), denominator.type())), type), + allAxes(tf, denominator)); + Operand iou = tf.math.divNoNan(truePositives, denominator); + + Operand iouSum = tf.reduceSum(iou, allAxes(tf, iou)); + return tf.math.divNoNan(iouSum, numValidEntries); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java new file mode 100644 index 00000000000..acf28f5b2cc --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanRelativeError.java @@ -0,0 +1,178 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the mean relative error by normalizing with the given values. + * + *

This metric creates two local variables, {@code total} and {@code count} that are + * used to compute the mean relative error. This is weighted by {@code sampleWeight}, and it is + * ultimately returned as mean relative error: an idempotent operation that simply divides total by + * count. + * + *

If {@code sampleWeight} is {@code null}, weights default to 1. Use {@code sampleWeight} + * of 0 to mask values. + * + * @param The data type for the metric result + */ +public class MeanRelativeError extends Mean { + private Operand normalizer; + + /** + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, float[] normalizer, long seed, Class type) { + this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * Creates a MeanRelativeError metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, String name, float[] normalizer, long seed, Class type) { + this(tf, name, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, double[] normalizer, long seed, Class type) { + this(tf, null, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * Creates a MeanRelativeError metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, String name, double[] normalizer, long seed, Class type) { + this(tf, name, cast(tf, tf.constant(normalizer), type), seed, type); + } + + /** + * Creates a MeanRelativeError metric using {@link Class#getSimpleName()} as the name + * + * @param tf the TensorFlow Ops + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanRelativeError(Ops tf, Operand normalizer, long seed, Class type) { + this(tf, null, normalizer, seed, type); + } + + /** + * Creates a MeanRelativeError metric + * + * @param tf the TensorFlow ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param normalizer The normalizer values with same shape as predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the type for the variables and result + */ + protected MeanRelativeError( + Ops tf, String name, Operand normalizer, long seed, Class type) { + super(tf, name, seed, type); + this.normalizer = normalizer; + } + + /** + * Accumulates metric statistics. + * + * @param labels The ground truth values. + * @param predictions The predicted values. Must be the same shape as the normalizer. + * @param sampleWeights Optional weighting of each example. A null value defaults to 1. Can be an + * {@code Operand} whose rank is either 0, or the same rank as {@code labels}, and must be + * broadcastable to {@code labels}. + * @return a List of Operations to update the metric state + */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Operand tSampleWeights = + sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + + LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); + tPredictions = tuple.getTarget(); + tLabels = tuple.getLabels(); + + tuple = LossesHelper.removeSqueezableDimensions(getTF(), normalizer, tPredictions); + normalizer = tuple.getLabels(); + tPredictions = tuple.getTarget(); + + if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with labels shape %s", + tPredictions.shape(), tLabels.shape())); + + Operand relativeErrors = + getTF() + .math + .divNoNan( + getTF().math.abs(getTF().math.sub(tLabels, tPredictions)), this.getNormalizer()); + + return super.updateStateList(relativeErrors, tSampleWeights); + } + + /** + * Gets the normalizer Operand + * + * @return the normalizer + */ + public Operand getNormalizer() { + return normalizer; + } +} diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java index 3c4c79d39ba..fd8be29875e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -26,6 +26,19 @@ /** * A metric that computes the mean of absolute difference between labels and predictions. * + *

The {@code MeanSquaredError} class creates two local variables, {@code total} and {@code + * count} that are used to compute the mean squared error. This average is weighted by {@code + * weights}, and it is ultimately returned as the mean squared error: an idempotent operation that + * simply divides {@code total} by {@code count}. + * + *

For estimation of the metric over a stream of data, the function creates an update operation + * that updates these variables. Internally, a squared error operation computes the element-wise + * square of the difference between {@code predictions} and {@code labels}. Then the update + * operation increments {@code total} with the reduced sum of the product of {@code weights} and the + * squared error, and it increments {@code count} with the reduced sum of {@code weights}. + * + *

If {@code weights} is null, weights default to 1. Use weights of 0 to mask values. + * * @param The data type for the metric result. */ public class MeanSquaredError extends MeanMetricWrapper @@ -45,7 +58,13 @@ public MeanSquaredError(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean squared error between the labels and predictions. + * + * @param labels the truth values or labels. Must be the same shape as predictions. + * @param predictions the predictions + * @return Computes the mean squared error between the labels and predictions. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java index d525bb76648..4728cbab12f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -45,7 +45,13 @@ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the mean squared logarithmic error between labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Mean squared logarithmic error values, shape = {@code [batch_size, d0, .. dN-1]}. 
+ */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java new file mode 100644 index 00000000000..d88d7a4c1b4 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanTensor.java @@ -0,0 +1,198 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.framework.metrics.impl.WeightsBroadcastOps; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Metric that computes the element-wise (weighted) mean of the given tensors. 
+ * + * @param The data type for the metric result + */ +public class MeanTensor extends Metric { + public static final String TOTAL = "total"; + public static final String COUNT = "count"; + private final String totalName; + private final String countName; + private final Class type; + private Shape shape; + private Variable total; + private Variable count; + private Assign totalInitializer; + private Assign countInitializer; + private boolean initialized; + + /** + * Creates a MeanTensor metric, using {@link Class#getSimpleName()} as the name + * + * @param tf the TensorFlow ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public MeanTensor(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + /** + * Creates a MeanTensor metric + * + * @param tf the TensorFlow ops + * @param name the name of this metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public MeanTensor(Ops tf, String name, long seed, Class type) { + super(tf, name, seed); + this.type = type; + this.totalName = this.getVariableName(TOTAL); + this.countName = this.getVariableName(COUNT); + } + + /** + * Creates the Operations that initialize the total and count variables. 
+ * + * @param shape the shape of the variables + * @return true if the variables need initialization, otherwise false; + */ + private boolean init(Shape shape) { + if (!initialized) { + this.shape = shape; + Zeros zeros = new Zeros<>(getTF()); + Operand zero = zeros.call(getTF().constant(shape), type); + + if (total == null) { + total = getTF().withName(totalName).variable(zero); + totalInitializer = getTF().assign(total, zero); + } + if (count == null) { + count = getTF().withName(countName).variable(zero); + countInitializer = getTF().assign(count, zero); + } + this.initialized = true; + return true; + } else { + return false; + } + } + + /** + * Accumulates statistics for computing the element-wise mean. + * + * @param values Per-example value. Input values must always have the same shape for all + * invocations of updateStateList. + * @param sampleWeights Optional weighting of each example. Defaults to 1 if null. + * @throws IllegalArgumentException if the shape of {@code values} in each subsequent call is not + * the same shape as {@code values} set during the first call + */ + @Override + public List updateStateList( + Operand values, Operand sampleWeights) { + Ops tf = getTF(); + Operand tValues = cast(tf, values, type); + Operand tSampleWeights = sampleWeights == null ? null : cast(tf, sampleWeights, type); + + // update the shape if it is the first call. + boolean needsInitialization = init(values.shape()); + + if (!this.shape.equals(values.shape())) { + throw new IllegalArgumentException( + String.format( + "MeanTensor input values must always have the same shape. Expected shape (set during the first call): %s. Got %s", + this.shape.toString(), values.shape().toString())); + } + + Operand numValues = tf.onesLike(tValues); + if (tSampleWeights != null) { + // Update dimensions of weights to match with values if possible. 
+ LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(tf, null, tValues, tSampleWeights); + tValues = tuple.getTarget(); + tSampleWeights = tuple.getSampleWeights(); + try { + // Broadcast weights if possible. + tSampleWeights = WeightsBroadcastOps.broadcastWeights(tf, tSampleWeights, tValues); + } catch (IllegalArgumentException ex) { + // sampleWeights cannot be broadcast to values + // Reduce values to same ndim as weight array + int ndim = values.shape().numDimensions(); + int weightNdim = tSampleWeights.asOutput().shape().numDimensions(); + int[] range = new int[ndim - weightNdim]; + for (int i = weightNdim; i < ndim; i++) { + range[i] = i; + } + tValues = tf.math.mean(tValues, tf.constant(range)); + } + numValues = tf.math.mul(numValues, tSampleWeights); + tValues = tf.math.mul(tValues, tSampleWeights); + } + + List controlOpsPre = new ArrayList<>(); + if (needsInitialization) { + controlOpsPre.add(countInitializer); + controlOpsPre.add(totalInitializer); + } + Ops tf1 = tf.withSubScope("variables").withControlDependencies(controlOpsPre); + + List controlOps = new ArrayList<>(); + controlOps.add(tf1.assignAdd(this.count, numValues)); + controlOps.add(tf1.assignAdd(this.total, tValues)); + return controlOps; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + if (!this.initialized) { + throw new IllegalStateException( + "MeanTensor does not have any result yet. 
Please use `.update_state(value)` before retrieving the result."); + } + return getTF().math.divNoNan(total, count); + } + + /** @return the total */ + public Variable getTotal() { + return total; + } + + /** @return the count */ + public Variable getCount() { + return count; + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + List controlOpsPre = new ArrayList<>(); + controlOpsPre.add(countInitializer); + controlOpsPre.add(totalInitializer); + return getTF().withSubScope("resetStates").withControlDependencies(controlOpsPre).noOp(); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java index 95b74bf1eea..a33750ac3f6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -15,15 +15,16 @@ package org.tensorflow.framework.metrics; import org.tensorflow.Operand; -import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -/** Helper class with built-in metrics functions. */ -public class Metrics { +import static org.tensorflow.framework.utils.CastHelper.cast; - public static final float L2_NORM_EPSILON = 1e-12f; +/** Static methods for computing metrics. */ +public class Metrics { /** * Computes how often targets are in the top K predictions. 
@@ -49,10 +50,62 @@ public class Metrics { */ public static Operand topKCategoricalAccuracy( Ops tf, Operand labels, Operand predictions, long k) { - Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); - return CastHelper.cast( + Operand fPredictions = cast(tf, predictions, TFloat32.class); + return cast( tf, tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), predictions.type()); } + + /** + * Computes how often integer targets are in the top K predictions. + * + *

Standalone usage: + * + *

+   *     Operand<TInt32> labels = tf.constant(new int[]{2, 1});
+   *     Operand<TFloat32> predictions = tf.constant(new float[][]
+   *                            {{0.1f, 0.9f, 0.08f}, {0.05f, 0.95f, 0f}});
+   *     Operand<TFloat32> m = Metrics.sparseTopKCategoricalAccuracy(
+   *                                    tf, labels, predictions, 3);
+   *     //m.shape().toString() == "[2]"
+   * 
+ * + * @param tf the TensorFlow Ops. + * @param labels the ground truth values. + * @param predictions The prediction values. + * @param k Number of top elements to look at for computing accuracy. + * @param the data type for the predictions and results + * @param the data type ofr the labels. + * @return the Operand for the Sparse top K categorical accuracy value. + */ + @SuppressWarnings("unchecked") + public static Operand sparseTopKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, int k) { + Operand tLabels = cast(tf, labels, predictions.type()); + + // Flatten predictions to (batch_size, num_samples) and labels to (num_samples,) + + int predictionsRank = predictions.shape().numDimensions(); + int labelsRank = tLabels.shape().numDimensions(); + + Operand castPredictions = cast(tf, predictions, TFloat32.class); + if (predictionsRank != Shape.UNKNOWN_SIZE && labelsRank != Shape.UNKNOWN_SIZE) { + if (predictionsRank > 2) { + // y_pred = array_ops.reshape(y_pred, [-1, y_pred.shape[-1]]) + castPredictions = + tf.reshape( + castPredictions, + tf.constant(castPredictions.shape().takeLast(1).prepend(Shape.UNKNOWN_SIZE))); + } + if (labelsRank > 1) { + // y_true = array_ops.reshape(y_true, [-1]) + tLabels = tf.reshape(tLabels, tf.constant(Shape.of(Shape.UNKNOWN_SIZE))); + } + } + return cast( + tf, + tf.nn.inTopK(castPredictions, cast(tf, tLabels, TInt32.class), tf.constant(k)), + predictions.type()); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java index 422fd4808ff..2e4bde8ec55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -44,7 +44,13 @@ public Poisson(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the Poisson loss between 
labels and predictions. + * + * @param labels the truth values or labels, shape = {@code [batch_size, d0, .. dN]}. + * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Poisson loss value, shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java new file mode 100644 index 00000000000..3812e799b75 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java @@ -0,0 +1,420 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the precision of the predictions with respect to the labels. + * + *

The metric creates two local variables, {@code truePositives} and {@code falsePositives + * } that are used to compute the precision. This value is ultimately returned as precision, + * an idempotent operation that simply divides {@code truePositives} by the sum of {@code + * truePositives} and {@code falsePositives}. + * + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of + * 0 to mask values. + * + *

If {@code topK} is set, the metric calculates precision as how often on average a class + * among the top-k classes with the highest predicted values of a batch entry is correct and can be + * found in the label for that entry. + * + *

If {@code classId} is specified, the metric calculates precision by considering only the + * entries in the batch for which {@code classId} is above the {@code thresholds} and/or + * in the top-k highest predictions, and computing the fraction of them for which {@code classId + * } is indeed a correct label. + * + * @param The data type for the metric result + */ +public class Precision extends Metric { + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final float DEFAULT_THRESHOLD = 0.5f; + + private final float[] thresholds; + private final Integer topK; + private final Integer classId; + private final String truePositivesName; + private final String falsePositivesName; + private final Class type; + private final List initializers = new ArrayList<>(); + private Variable truePositives; + private Variable falsePositives; + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values and with a threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, long seed, Class type) { + this(tf, null, null, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values with a threshold of {@link + * #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Precision(Ops tf, String name, long seed, Class type) { + this(tf, name, null, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values. + * + * @param tf the TensorFlow Ops + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} and no topK or classId + * values. + * + * @param tf the TensorFlow Ops + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. 
+ * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, null, null, seed, type); + } + + /** + * Creates a Precision Metric with no topK or classId values. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision(Ops tf, String name, float[] thresholds, long seed, Class type) { + this(tf, name, thresholds, null, null, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. 
+ * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, float threshold, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, thresholds, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. 
If null, name defaults to {@link + * Class#getSimpleName()}. + * @param threshold Optional threshold value in the range {@code [0, 1]}. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, + String name, + float threshold, + Integer topK, + Integer classId, + long seed, + Class type) { + this(tf, name, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Precision Metric. + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds Optional threshold values in the range {@code [0, 1]}. A threshold is + * compared with prediction values to determine the truth value of predictions (i.e., above + * the threshold is true, below is false). One metric value is generated for each threshold + * value. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Precision( + Ops tf, + String name, + float[] thresholds, + Integer topK, + Integer classId, + long seed, + Class type) { + super(tf, name, seed); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); + float defaultThreshold = topK == null ? DEFAULT_THRESHOLD : MetricsHelper.NEG_INF; + this.thresholds = thresholds == null ? new float[] {defaultThreshold} : thresholds; + this.topK = topK; + this.classId = classId; + + init(); + } + + /** Initializes the variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Operand zero = zeros.call(tf.constant(Shape.of(thresholds.length)), type); + + if (this.truePositives == null) { + this.truePositives = tf.withName(truePositivesName).variable(zero); + initializers.add(tf.assign(truePositives, zero)); + } + if (this.falsePositives == null) { + this.falsePositives = + tf.withName(falsePositivesName) + .variable(zero); + initializers.add(tf.assign(falsePositives, zero)); + } + } + + /** + * Accumulates true positive and false positive statistics. + * + * @param labels the labels The ground truth values, with the same dimensions as predictions. Will + * be cast to {@link TBool}. + * @param predictions the predictions, each element must be in the range {@code [0, 1]}. + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or * + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. 
+ */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Ops tf = getTF(); + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, falsePositives); + + Operand tPredictions = cast(tf, predictions, type); + Operand tLabels = cast(tf, labels, type); + Operand tSampleWeights = sampleWeights != null ? cast(tf, sampleWeights, type) : null; + + return new ArrayList( + MetricsHelper.updateConfusionMatrixVariables( + tf, + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + tf.constant(thresholds), + topK, + classId, + tSampleWeights, + false, + null)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + return thresholds.length == 1 + ? 
tf.reshape(tf.slice( + result, + tf.expandDims(tf.constant(0), tf.constant(0)), + tf.expandDims(tf.constant(1), tf.constant(0))), + tf.constant(Shape.scalar())) + : result; + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); + } + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return thresholds; + } + + /** + * Gets the topK value, may be null + * + * @return the topK value or null + */ + public Integer getTopK() { + return topK; + } + + /** + * Gets the classId, may be null + * + * @return the classId or null + */ + public Integer getClassId() { + return classId; + } + + /** + * Gets the truePositives variable + * + * @return the truePositives + */ + public Variable getTruePositives() { + return truePositives; + } + + /** + * Gets the falsePositives variable + * + * @return the falsePositives + */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** + * Gets the name of the truePositives variable + * + * @return the truePositivesName + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the name of the falsePositives variable + * + * @return the falsePositivesName + */ + public String getFalsePositivesName() { + return falsePositivesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java new file mode 100644 index 00000000000..5f5f9b47a10 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/PrecisionAtRecall.java @@ -0,0 +1,137 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best precision where recall is >= specified value. + * + *

This metric creates four local variables, truePositives, trueNegatives, falsePositives and + * falseNegatives that are used to compute the precision at the given recall. The threshold for the + * given recall value is computed and used to evaluate the corresponding precision. + * + *

If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of + * 0 to mask values. + * + * @param The data type for the metric result + */ +public class PrecisionAtRecall extends SensitivitySpecificityBase { + + private final float recall; + + /** + * Creates a PrecisionAtRecall metric with a name of {@link Class#getSimpleName()} and {@link + * #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param recall the recall. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. + */ + public PrecisionAtRecall(Ops tf, float recall, long seed, Class type) { + this(tf, null, recall, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionAtRecall metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number of + * thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param recall the recall. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. + */ + public PrecisionAtRecall(Ops tf, String name, float recall, long seed, Class type) { + this(tf, name, recall, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a PrecisionAtRecall metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param recall the recall. 
A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. + */ + public PrecisionAtRecall(Ops tf, float recall, int numThresholds, long seed, Class type) { + this(tf, null, recall, numThresholds, seed, type); + } + + /** + * Creates a PrecisionAtRecall metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param recall the recall. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if recall is not in the range + * [0-1]. 
+ */ + public PrecisionAtRecall( + Ops tf, String name, float recall, int numThresholds, long seed, Class type) { + super(tf, name, numThresholds, seed, type); + if (recall < 0f || recall > 1f) + throw new IllegalArgumentException("recall must be in the range [0, 1]."); + this.recall = recall; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + + Operand div = + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + Operand sub = tf.math.sub(div, cast(tf, tf.constant(recall), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(truePositives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(falsePositives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** + * Gets the recall value + * + * @return the recall value + */ + public float getRecall() { + return recall; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java new file mode 100644 index 00000000000..3886ec050b0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Recall.java @@ -0,0 +1,440 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.framework.metrics.impl.MetricsHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes the recall of the predictions with respect to the labels. + * + *

This metric creates two local variables, {@code truePositives} and {@code falseNegatives + * }, that are used to compute the recall. This value is ultimately returned as recall, an + * idempotent operation that simply divides {@code truePositives} by the sum of {@code + * truePositives} and {@code falseNegatives}. + * + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sampleWeights of + * 0 to mask values. + * + *

If {@code topK} is set, the metric calculates recall as how often on average a class + * among the labels of a batch entry is in the top-k predictions. + * + *

If {@code classId} is specified, the metric calculates recall by considering only the + * entries in the batch for which {@code classId} is in the label, and computing the fraction + * of them for which {@code classId} is above the threshold and/or in the top-k predictions. + * + * @param The data type for the metric result + */ +public class Recall extends Metric { + public static final float DEFAULT_THRESHOLD = 0.5f; + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + + private final float[] thresholds; + private final Integer topK; + private final Integer classId; + private final String truePositivesName; + private final String falseNegativesName; + private final Class type; + private final List initializers = new ArrayList<>(); + private Variable truePositives; + private Variable falseNegatives; + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null, and thresholds set to {@link #DEFAULT_THRESHOLD} + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, long seed, Class type) { + this(tf, null, null, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null and thresholds set to {@link + * #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public Recall(Ops tf, String name, long seed, Class type) { + this(tf, name, null, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null. + * + * @param tf The TensorFlow Ops + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float threshold, long seed, Class type) { + this(tf, null, threshold, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()}, and topK and classId set + * to null. + * + * @param tf The TensorFlow Ops + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. 
+ * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, threshold, null, null, seed, type); + } + + /** + * Creates a Recall metric with topK and classId set to null. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, float[] thresholds, long seed, Class type) { + this(tf, name, thresholds, null, null, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} and using a threshold + * value of {@link #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, null, topK, classId, seed, type); + } + + /** + * Creates a Recall metric using a threshold value of {@link #DEFAULT_THRESHOLD}. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, String name, Integer topK, Integer classId, long seed, Class type) { + this(tf, name, null, topK, classId, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall(Ops tf, float threshold, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Recall metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, float[] thresholds, Integer topK, Integer classId, long seed, Class type) { + this(tf, null, thresholds, topK, classId, seed, type); + } + + /** + * Creates a Recall metric. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param threshold A threshold is compared with prediction values to determine the truth value of + * predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults to + * {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. 
This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, + String name, + float threshold, + Integer topK, + Integer classId, + long seed, + Class type) { + this(tf, name, new float[] {threshold}, topK, classId, seed, type); + } + + /** + * Creates a Recall metric. + * + * @param tf The TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param thresholds A threshold is compared with prediction values to determine the truth value + * of predictions (i.e., above the threshold is `true`, below is `false`). If null, defaults + * to {@link #DEFAULT_THRESHOLD}. + * @param topK An optional value specifying the top-k predictions to consider when calculating + * precision. + * @param classId Optional Integer class ID for which we want binary metrics. This must be in the + * half-open interval [0, numClasses], where numClasses is the last dimension of predictions. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public Recall( + Ops tf, + String name, + float[] thresholds, + Integer topK, + Integer classId, + long seed, + Class type) { + super(tf, name, seed); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); + float defaultThreshold = topK == null ? DEFAULT_THRESHOLD : MetricsHelper.NEG_INF; + + this.thresholds = thresholds == null ? 
new float[] {defaultThreshold} : thresholds; + this.topK = topK; + this.classId = classId; + + init(); + } + + /** Initializes the Variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Operand zero = zeros.call(tf.constant(Shape.of(this.thresholds.length)), type); + if (truePositives == null) { + + truePositives = tf.withName(truePositivesName).variable(zero); + initializers.add(tf.assign(truePositives, zero)); + } + + if (this.falseNegatives == null) { + + falseNegatives = tf.withName(falseNegativesName).variable(zero); + initializers.add(tf.assign(falseNegatives, zero)); + } + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return getTF().withSubScope("resetStates").withControlDependencies(initializers).noOp(); + } + + /** + * Accumulates true positive and false negative statistics. + * + * @param labels the labels The ground truth values, with the same dimensions as predictions. Will + * be cast to {@link TBool}. + * @param predictions the predictions, each element must be in the range {@code [0, 1]}. + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or * + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Ops tf = getTF(); + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, this.truePositives); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, this.falseNegatives); + + Operand tPredictions = cast(tf, predictions, type); + Operand tLabels = cast(tf, labels, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(tf, sampleWeights, type) : null; + + return MetricsHelper.updateConfusionMatrixVariables( + tf, + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + tf.constant(thresholds), + topK, + classId, + tSampleWeights, + false, + null); + } + + @Override + public Operand result() { + Ops tf = getTF(); + Operand result = + tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives)); + return this.thresholds.length == 1 + ? tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1])) + : result; + } + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return this.thresholds; + } + + /** + * Gets the topK value + * + * @return the topK value + */ + public Integer getTopK() { + return this.topK; + } + + /** + * Gets the class id + * + * @return the class id + */ + public Integer getClassId() { + return this.classId; + } + + /** + * Gets the truePositives variable + * + * @return the truePositives variable + */ + public Variable getTruePositives() { + return this.truePositives; + } + + /** + * Gets the falseNegatives variable + * + * @return the falseNegatives variable + */ + public Variable getFalseNegatives() { + return this.falseNegatives; + } + + /** + * Gets the truePositives variable name + * + * @return the truePositives variable name + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the falseNegatives variable name + * + * @return the falseNegatives variable name + */ + public String getFalseNegativesName() { + return falseNegativesName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java new file mode 100644 index 00000000000..a3fc2f77b7f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RecallAtPrecision.java @@ -0,0 
+1,147 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Where; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best recall where precision is >= specified value. + * + *

For a given score-label-distribution the required precision might not be achievable, in this + * case 0.0 is returned as recall. + * + *

This metric creates four local variables, truePositives, trueNegatives, falsePositives and + * falseNegatives that are used to compute the recall at the given precision. The threshold for the + * given precision value is computed and used to evaluate the corresponding recall. + * + *

If {@code sampleWeights} is null, weights default to 1. Use {@code sampleWeights} of + * 0 to mask values. + * + * @param The data type for the metric result + */ +public class RecallAtPrecision extends SensitivitySpecificityBase { + + private final float precision; + + /** + * Creates a RecallAtPrecision metric with a name of {@link Class#getSimpleName()} and {@link + * #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param precision the precision. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if precision is not in the range + * [0-1]. + */ + public RecallAtPrecision(Ops tf, float precision, long seed, Class type) { + this(tf, null, precision, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a RecallAtPrecision metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number of + * thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric. If null, defaults to {@link Class#getSimpleName()} + * @param precision the precision. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if precision is not in the range + * [0-1]. + */ + public RecallAtPrecision(Ops tf, String name, float precision, long seed, Class type) { + this(tf, name, precision, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a RecallAtPrecision metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param precision the precision. 
A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * precision. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if precision is not in the range + * [0-1]. + */ + public RecallAtPrecision(Ops tf, float precision, int numThresholds, long seed, Class type) { + this(tf, null, precision, numThresholds, seed, type); + } + + /** + * Creates a RecallAtPrecision metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param precision the precision. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * precision. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if precision is not in the range + * [0-1]. 
+ */ + public RecallAtPrecision( + Ops tf, String name, float precision, int numThresholds, long seed, Class type) { + super(tf, name, numThresholds, seed, type); + if (precision < 0f || precision > 1f) + throw new IllegalArgumentException("precision must be in the range [0, 1]."); + this.precision = precision; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + + Operand precisions = + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives)); + Operand recalls = + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + Operand isFeasible = + tf.math.greaterEqual(precisions, cast(tf, tf.constant(precision), getType())); + Where feasible = tf.where(isFeasible); + Operand feasibleExists = tf.math.greater(tf.size(feasible), tf.constant(0)); + + Operand gather = tf.expandDims(tf.gather(recalls, feasible, tf.constant(0)), tf.constant(0)); + return tf.select( + feasibleExists, + tf.reduceMax(gather, allAxes(tf, gather)), + cast(tf, tf.constant(0), getType())); + } + + /** + * Gets the precision + * + * @return the precision + */ + public float getPrecision() { + return precision; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java new file mode 100644 index 00000000000..3886428425b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/RootMeanSquaredError.java @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes root mean squared error metric between {@code labels} and {@code predictions} + * . + * + * @param The data type for the metric result + */ +public class RootMeanSquaredError extends Mean { + + /** + * Creates a RootMeanSquaredError metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public RootMeanSquaredError(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a RootMeanSquaredError metric + * + * @param tf the TensorFlow Ops + * @param name name of the metric instance. If null, name defaults to {@link + * Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public RootMeanSquaredError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + } + + /** + * Accumulates root mean squared error statistics. + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or * + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Operand tSampleWeights = + sampleWeights != null ? cast(getTF(), sampleWeights, getResultType()) : null; + + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(getTF(), tLabels, tPredictions); + tPredictions = ops.getTarget(); + tLabels = ops.getLabels(); + + Operand errorSquared = + cast(getTF(), getTF().math.squaredDifference(tPredictions, tLabels), getResultType()); + + return super.updateStateList(errorSquared, tSampleWeights); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + return getTF().math.sqrt(getTF().math.divNoNan(this.total, this.count)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java new file mode 100644 index 00000000000..29c0504b823 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SensitivityAtSpecificity.java @@ -0,0 +1,147 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best sensitivity where specificity is >= specified value. + * + *

{@code Sensitivity} measures the proportion of actual positives that are correctly + * identified as such {@code (tp / (tp + fn))}. + * + *

{@code Specificity} measures the proportion of actual negatives that are correctly + * identified as such {@code (tn / (tn + fp))}. + * + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives + * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the + * sensitivity at the given specificity. The threshold for the given specificity value is computed + * and used to evaluate the corresponding sensitivity. + * + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of + * 0 to mask values. + * + * @see Additional information + * about specificity and sensitivity + * @param The data type for the metric result + */ +public class SensitivityAtSpecificity extends SensitivitySpecificityBase { + + private final float specificity; + + /** + * Creates a SensitivityAtSpecificity metric with a name of {@link Class#getSimpleName()} and + * {@link #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param specificity the specificity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. + */ + public SensitivityAtSpecificity(Ops tf, float specificity, long seed, Class type) { + this(tf, null, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a SensitivityAtSpecificity metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number + * of thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param specificity the specificity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ */ + public SensitivityAtSpecificity( + Ops tf, String name, float specificity, long seed, Class type) { + this(tf, name, specificity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a SensitivityAtSpecificity metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param specificity the specificity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * specificity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. + */ + public SensitivityAtSpecificity( + Ops tf, float specificity, int numThresholds, long seed, Class type) { + this(tf, null, specificity, numThresholds, seed, type); + } + + /** + * Creates a SensitivityAtSpecificity metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param specificity the specificity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * specificity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if specificity is not in the range + * [0-1]. 
+ */ + public SensitivityAtSpecificity( + Ops tf, String name, float specificity, int numThresholds, long seed, Class type) { + super(tf, name, numThresholds, seed, type); + if (specificity < 0f || specificity > 1f) + throw new IllegalArgumentException("specificity must be in the range [0, 1]."); + this.specificity = specificity; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Ops tf = getTF(); + Operand specificities = + tf.math.divNoNan(trueNegatives, tf.math.add(trueNegatives, falsePositives)); + Operand sub = tf.math.sub(specificities, cast(tf, tf.constant(specificity), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(truePositives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(falseNegatives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** + * Gets the specificity + * + * @return the specificity + */ + public float getSpecificity() { + return specificity; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java new file mode 100644 index 00000000000..5294f798044 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalAccuracy.java @@ -0,0 +1,130 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.op.math.Equal; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.Collections; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Calculates how often predictions matches integer labels. + * + *

You can provide logits of classes as {@code predictions}, since argmax of logits and + * probabilities are same. + * + *

This metric creates two local variables, `total` and `count` that are used to compute the + * frequency with which {@code predictions} matches {@code labels}. This frequency is + * ultimately returned as `sparse categorical accuracy`: an idempotent operation that simply divides + * `total` by `count`. + * + *

If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0 to mask values.' + * + *

Usage: + * + *

+ * SparseCategoricalAccuracy m = new org.tensorflow.framework.metrics.SparseCategoricalAccuracy();
+ * m.update_state(tf.constant(new float[][] {{2}, {1}},
+ *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+ * Operand<TFloat32> result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.5
+ * 
+ * + *
+ * m.reset_states()
+ * m.update_state(
+ *     tf.constant(new float[][] {{2}, {1}},
+ *     tf.constant(new float[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}},
+ *     tf.constant(new float[] {0.7f, 0.3f});
+ * result = m.result();
+ * System.out.println(result.data().getFloat());
+ * 0.3
+ * 
+ * + * @param The data type for the metric result + */ +public class SparseCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a SparseCategoricalAccuracy metric, using name of {@link Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type The data type for the metric result + */ + public SparseCategoricalAccuracy(Ops tf, long seed, Class type) { + this(tf, null, seed, type); + } + + /** + * Creates a SparseCategoricalAccuracy metric. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null use {@link Class#getSimpleName()} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type of the metric result. + */ + public SparseCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + super.setLoss(this); + } + + /** + * Calculates how often predictions matches integer labels. + * + * @param labels Integer ground truth values. + * @param predictions the predictions + * @return Sparse categorical accuracy values. 
+ */ + @Override + public Operand call( + Operand labels, Operand predictions) { + + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + Shape predShape = predictions.asOutput().shape(); + Shape labelsShape = labels.asOutput().shape(); + long predictionsRank = predShape.numDimensions(); + long labelsRank = labelsShape.numDimensions(); + + // If the shape of labels is (num_samples, 1), squeeze to (num_samples,) + if (predictionsRank != Shape.UNKNOWN_SIZE + && labelsRank != Shape.UNKNOWN_SIZE + && labelsShape.size((int) labelsRank - 1) == 1) { + tLabels = getTF().squeeze(tLabels, Squeeze.axis(Collections.singletonList(labelsRank - 1L))); + } + Operand argMaxPred = + cast( + getTF(), + getTF().math.argMax(tPredictions, getTF().constant(-1L), TInt64.class), + getResultType()); + + Equal equals = getTF().math.equal(tLabels, argMaxPred); + return getTF().dtypes.cast(equals, getResultType()); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java index e954169b2af..04555d85b66 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -27,6 +27,9 @@ * A metric that computes the sparse categorical cross-entropy loss between true labels and * predicted labels. * + *

You can provide logits of classes as predictions, since argmax of logits and probabilities are + * the same. + * + * @param The data type for the metric result. */ public class SparseCategoricalCrossentropy extends MeanMetricWrapper @@ -55,7 +58,13 @@ public SparseCategoricalCrossentropy( this.axis = axis; } - /** {@inheritDoc} */ + /** + * Computes the sparse categorical cross-entropy loss between the labels and predictions. + * + * @param labels Integer ground truth values. + * @param predictions the predictions + * @return Sparse categorical cross-entropy loss values. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java new file mode 100644 index 00000000000..29dc91298d3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseTopKCategoricalAccuracy.java @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes how often integer targets are in the top `K` predictions. + * + * @param The data type for the metric result + */ +public class SparseTopKCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a SparseTopKCategoricalAccuracy metric using {@link #DEFAULT_K} for the number of top + * elements. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a SparseTopKCategoricalAccuracy metric. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the date type for the result + */ + public SparseTopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** + * Computes how often integer targets are in the top {@code K} predictions. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @return Sparse top K categorical accuracy value. + */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Metrics.sparseTopKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java new file mode 100644 index 00000000000..2cb7e54eba0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SpecificityAtSensitivity.java @@ -0,0 +1,148 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.SensitivitySpecificityBase; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes best specificity where sensitivity is >= specified value. {@code Sensitivity} + * measures the proportion of actual positives that are correctly identified as such {@code + * (tp / (tp + fn))}. + * + *

{@code Specificity} measures the proportion of actual negatives that are correctly + * identified as such {@code (tn / (tn + fp))}. + * + *

This metric creates four local variables, {@code truePositives}, {@code trueNegatives + * }, {@code falsePositives} and {@code falseNegatives} that are used to compute the + * specificity at the given sensitivity. The threshold for the given sensitivity value is computed + * and used to evaluate the corresponding specificity. + * + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use sample_weight of + * 0 to mask values. + * + * @see Additional information + * about specificity and sensitivity + * @param The data type for the metric result + */ +public class SpecificityAtSensitivity extends SensitivitySpecificityBase { + + private final float sensitivity; + + /** + * Creates a SpecificityAtSensitivity metric with a name of {@link Class#getSimpleName()} and + * {@link #DEFAULT_NUM_THRESHOLDS} for the number of thresholds + * + * @param tf The TensorFlow Ops + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. + */ + public SpecificityAtSensitivity(Ops tf, float sensitivity, long seed, Class type) { + this(tf, null, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a SpecificityAtSensitivity metric with {@link #DEFAULT_NUM_THRESHOLDS} for the number + * of thresholds + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. 
+ */ + public SpecificityAtSensitivity( + Ops tf, String name, float sensitivity, long seed, Class type) { + this(tf, name, sensitivity, DEFAULT_NUM_THRESHOLDS, seed, type); + } + + /** + * Creates a SpecificityAtSensitivity metric with a name of {@link Class#getSimpleName()}. + * + * @param tf The TensorFlow Ops + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * sensitivity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. + */ + public SpecificityAtSensitivity( + Ops tf, float sensitivity, int numThresholds, long seed, Class type) { + this(tf, null, sensitivity, numThresholds, seed, type); + } + + /** + * Creates a SpecificityAtSensitivity metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric, if null defaults to {@link Class#getSimpleName()} + * @param sensitivity the sensitivity. A scalar value in range [0, 1] + * @param numThresholds Defaults to 200. The number of thresholds to use for matching the given + * sensitivity. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0 or if sensitivity is not in the range + * [0-1]. 
+ */ + public SpecificityAtSensitivity( + Ops tf, String name, float sensitivity, int numThresholds, long seed, Class type) { + super(tf, name, numThresholds, seed, type); + if (sensitivity < 0f || sensitivity > 1f) + throw new IllegalArgumentException("sensitivity must be in the range [0, 1]."); + this.sensitivity = sensitivity; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + + Ops tf = getTF(); + + Operand sensitivities = + tf.math.divNoNan(truePositives, tf.math.add(truePositives, falseNegatives)); + Operand sub = tf.math.sub(sensitivities, cast(tf, tf.constant(sensitivity), getType())); + Operand minIndex = tf.math.argMin(tf.math.abs(sub), tf.constant(0), TInt32.class); + minIndex = tf.expandDims(minIndex, tf.constant(0)); + + Operand trueSlice = tf.slice(trueNegatives, minIndex, tf.constant(new int[] {1})); + Operand falseSlice = tf.slice(falsePositives, minIndex, tf.constant(new int[] {1})); + return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice)); + } + + /** + * Gets the sensitivity + * + * @return the sensitivity + */ + public float getSensitivity() { + return sensitivity; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java index 19b3b1d0ac4..e2ff208b8f5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -44,7 +44,15 @@ public SquaredHinge(Ops tf, String name, long seed, Class type) { setLoss(this); } - /** {@inheritDoc} */ + /** + * Computes the squared hinge loss between labels and predictions. + * + * @param labels The ground truth values. {@code labels} values are expected to be -1 or 1. If + * binary (0 or 1) labels are provided we will convert them to -1 or 1. shape = {@code + * [batch_size, d0, .. dN]}. 
+ * @param predictions the predictions, shape = {@code [batch_size, d0, .. dN]}. + * @return Squared hinge loss values. shape = {@code [batch_size, d0, .. dN-1]}. + */ @Override public Operand call( Operand labels, Operand predictions) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java new file mode 100644 index 00000000000..637ca6cdd05 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Sum.java @@ -0,0 +1,58 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.Reduce; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Computes the (weighted) sum of the given values. + * + *

For example, if values is {@code [1, 3, 5, 7]} then the sum is {@code 16}. If the + * weights were specified as {@code [1, 1, 0, 0]}, then the sum would be {@code 4.} + * + *

This metric creates one variable, {@code total}, that is used to compute the sum of + * values. This is ultimately returned as sum. + * + *

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. + */ +public class Sum extends Reduce { + + /** + * Creates a Sum metric with a name of {@link Class#getSimpleName()} + * + * @param tf The TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Sum(Ops tf, long seed, Class type) { + super(tf, null, MetricReduction.SUM, seed, type); + } + + /** + * Creates a Sum metric. + * + * @param tf The TensorFlow Ops + * @param name the name of the metric instance. If null, defaults to {@link Class#getSimpleName()} + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Sum(Ops tf, String name, long seed, Class type) { + super(tf, name, MetricReduction.SUM, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java new file mode 100644 index 00000000000..0146552433f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracy.java @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Computes how often targets are in the top {@code K} predictions. + * + * @param The data type for the metric result + */ +public class TopKCategoricalAccuracy extends MeanMetricWrapper + implements LossMetric { + public static final int DEFAULT_K = 5; + /** Number of top elements to look at for computing accuracy. */ + private final int k; + + /** + * Creates a TopKCategoricalAccuracy metric using {@link #DEFAULT_K} for {@code k}, Number of + * top elements to look at for computing accuracy. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type The data type for the metric result + */ + public TopKCategoricalAccuracy(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_K, seed, type); + } + + /** + * Creates a TopKCategoricalAccuracy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param k Number of top elements to look at for computing accuracy. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type The data type for the metric result + */ + public TopKCategoricalAccuracy(Ops tf, String name, int k, long seed, Class type) { + super(tf, name, seed, type); + this.k = k; + setLoss(this); + } + + /** + * Computes how often targets are in the top {@code K} predictions. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @return Top K categorical accuracy value. + */ + @Override + public Operand call( + Operand labels, Operand predictions) { + Operand tLabels = cast(getTF(), labels, getResultType()); + Operand tPredictions = cast(getTF(), predictions, getResultType()); + return Metrics.topKCategoricalAccuracy(getTF(), tLabels, tPredictions, k); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java new file mode 100644 index 00000000000..5c65f8c469f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TrueNegatives.java @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of true negatives. + * + *

If {@code sampleWeights} is given, calculates the sum of the weights of true negatives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of + * the number of true negatives. + * + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. + * + * @param The data type for the metric result + */ +public class TrueNegatives extends ConfusionMatrixConditionCount { + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a TrueNegatives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). 
One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a TrueNegatives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TrueNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a TrueNegatives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TrueNegatives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.TRUE_NEGATIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java new file mode 100644 index 00000000000..f0dd8c42de5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/TruePositives.java @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.ConfusionMatrixConditionCount; +import org.tensorflow.framework.metrics.impl.ConfusionMatrixEnum; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Metric that calculates the number of true positives. + * + *

If {@code sampleWeights} is given, calculates the sum of the weights of true positives. + * This metric creates one local variable, {@code accumulator} that is used to keep track of + * the number of true positives. + * + *

If {@code sampleWeights} is {@code null}, weights default to 1. Use {@code + * sampleWeights} of 0 to mask values. + * + * @param The data type for the metric result + */ +public class TruePositives extends ConfusionMatrixConditionCount { + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name and a + * default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, long seed, Class type) { + this(tf, null, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, float threshold, long seed, Class type) { + this(tf, null, new float[] {threshold}, seed, type); + } + + /** + * Creates a TruePositives metric, using {@link Class#getSimpleName()} for the metric name + * + * @param tf the TensorFlow Ops + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). 
One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, float[] thresholds, long seed, Class type) { + this(tf, null, thresholds, seed, type); + } + + /** + * Creates a TruePositives metric, using a default threshold of {@link #DEFAULT_THRESHOLD}. + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_THRESHOLD, seed, type); + } + + /** + * Creates a TruePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param threshold a threshold value in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, float threshold, long seed, Class type) { + this(tf, name, new float[] {threshold}, seed, type); + } + + /** + * Creates a TruePositives metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param thresholds threshold values in the range {@code [0, 1]}. A threshold is compared + * with prediction values to determine the truth value of predictions (i.e., above the + * threshold is {@code true}, below is {@code false}). One metric value is generated + * for each threshold value + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public TruePositives(Ops tf, String name, float[] thresholds, long seed, Class type) { + super(tf, name, ConfusionMatrixEnum.TRUE_POSITIVES, thresholds, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java new file mode 100644 index 00000000000..88597cf85ec --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixConditionCount.java @@ -0,0 +1,196 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Abstract base class that calculates the value of the given confusion matrix condition based on + * labels and predictions. + * + * @param The data type for the metric result + */ +public abstract class ConfusionMatrixConditionCount extends Metric { + public static final String ACCUMULATOR = "accumulator"; + public static final float DEFAULT_THRESHOLD = 0.5f; + private final ConfusionMatrixEnum confusionMatrixCond; + private final float[] thresholds; + private final String accumulatorName; + private final Class type; + private Variable accumulator; + private Assign initializer; + + /** + * Creates a ConfusionMatrixConditionCount type of Metric, using a threshold of {@link + * #DEFAULT_THRESHOLD} + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. 
+ * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, String name, ConfusionMatrixEnum confusionMatrixCond, long seed, Class type) { + this(tf, name, confusionMatrixCond, DEFAULT_THRESHOLD, seed, type); + } + /** + * Creates a ConfusionMatrixConditionCount type of Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param threshold a threshold value in {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each + * threshold value. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, + String name, + ConfusionMatrixEnum confusionMatrixCond, + float threshold, + long seed, + Class type) { + this(tf, name, confusionMatrixCond, new float[] {threshold}, seed, type); + } + + /** + * Creates a ConfusionMatrixConditionCount type of Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric, if null then {@link Class#getSimpleName()} is used + * @param confusionMatrixCond the confusion matrix condition to calculate + * @param thresholds threshold values in {@code [0, 1]}. A threshold is compared with + * prediction values to determine the truth value of predictions (i.e., above the threshold is + * {@code true}, below is {@code false}). One metric value is generated for each + * threshold value. + * @param seed the seed for random number generation. 
An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + */ + public ConfusionMatrixConditionCount( + Ops tf, + String name, + ConfusionMatrixEnum confusionMatrixCond, + float[] thresholds, + long seed, + Class type) { + super(tf, name, seed); + accumulatorName = this.getVariableName(ACCUMULATOR); + this.type = type; + this.confusionMatrixCond = confusionMatrixCond; + this.thresholds = thresholds; + init(); + } + + /** Initialize the metric */ + private void init() { + Shape variableShape = Shape.of(this.thresholds.length); + + Zeros zeros = new Zeros<>(getTF()); + accumulator = + getTF() + .withName(getAccumulatorName()) + .variable(zeros.call(getTF().constant(variableShape), type)); + initializer = getTF().assign(accumulator, zeros.call(getTF().constant(variableShape), type)); + } + + /** + * Gets the initializer for the accumulator variable + * + * @return the initializer for the accumulator variable + */ + public Assign getInitializer() { + return initializer; + } + + /** + * Accumulates the metric statistics. + * + * @param labels The ground truth values. + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ + @Override + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Ops tf = getTF(); + Operand tLabels = cast(tf, labels, type); + Operand tPredictions = cast(tf, predictions, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(tf, sampleWeights, type) : null; + return new ArrayList<>( + MetricsHelper.updateConfusionMatrixVariables( + getTF(), + Collections.singletonMap(confusionMatrixCond, accumulator), + Collections.singletonMap(confusionMatrixCond, initializer), + tLabels, + tPredictions, + tf.constant(thresholds), + null, + null, + tSampleWeights, + false, + null)); + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + return getTF().identity(accumulator); + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializer; + } + + /** + * get the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return this.thresholds; + } + + /** @return the accumulatorName */ + public String getAccumulatorName() { + return accumulatorName; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java new file mode 100644 index 00000000000..caa5f203f9f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/ConfusionMatrixEnum.java @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +/** Enumerate the values for a confusion matrix. 
*/ +public enum ConfusionMatrixEnum { + /** These are cases in which the prediction is true, and reality is true. */ + TRUE_POSITIVES("tp"), + /** These are cases in which the prediction is true, and reality is false. */ + FALSE_POSITIVES("fp"), + /** These are cases in which the prediction is false, and reality is false. */ + TRUE_NEGATIVES("tn"), + /** These are cases in which the prediction is false, and reality is true. */ + FALSE_NEGATIVES("fn"); + + private final String abbrev; + + /** + * Creates a ConfusionMatrixEnum + * + * @param abbrev the abbreviation for the confusion condition as required by the underlying + * TensorFlow api. + */ + ConfusionMatrixEnum(String abbrev) { + this.abbrev = abbrev; + } + + /** + * Gets the ConfusionMatrixEnum for this enum value, regardless of case. + * + * @param item either the name of the enumeration value or the abbreviation. + * @return ConfusionMatrixEnum for this enum value, or null if not found. + */ + public static ConfusionMatrixEnum get(String item) { + ConfusionMatrixEnum cm = valueOf(item.toUpperCase()); + if (cm == null) { + for (ConfusionMatrixEnum m : values()) { + if (m.getAbbreviation().equals(item.toLowerCase())) { + return m; + } + } + } + return null; + } + + /** + * Gets the abbreviation for this enum value + * + * @return the abbreviation for this enum value as required by the underlying TensorFlow api. 
+ */ + public String getAbbreviation() { + return abbrev; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java index 1fb3d3bb580..f89047e457d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -25,7 +25,7 @@ public interface LossMetric { /** - * Calculates the weighted loss between labels and predictions + * Calculates the weighted loss between {@code labels} and {@code predictions} * * @param labels the truth values or labels * @param predictions the predictions diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java index 9a532a0294f..37bdd5849ae 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -29,8 +29,8 @@ * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of * {@link MetricReduction#WEIGHTED_MEAN}. * - *

The loss function calculates the loss between the labels and predictions - * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the + *

The loss function calculates the loss between the {@code labels} and {@code predictions + * } then passes this loss to the {@link Mean} metric to calculate the weighted mean of the * loss over many iterations or epochs * * @param The data type for the metric result diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 8a352322f52..40336233d21 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,20 +15,36 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; +import org.tensorflow.framework.utils.SparseTensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.op.core.Stack; +import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.TopK; import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; 
import static org.tensorflow.framework.utils.CastHelper.cast; @@ -43,8 +59,8 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the sampleWeights can be broadcast to the same shape as values - * + * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values + * } * *

In losses and metrics, limited weight broadcasting is supported. Weights must be either * scalar, or the same rank as the target values, with each dimension either 1, or the same as the @@ -53,11 +69,11 @@ public class MetricsHelper { * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return Operation with control dependencies to ensure sampleWeight - * can be broadcast to values + * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} + * can be broadcast to {@code values} * @param the type of Operand - * @throws NotBroadcastableException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to values + * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an + * incorrect shape that prohibit broadcasting to {@code values} */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -78,7 +94,7 @@ public static Op assertBroadcastable( && !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") - .withControlDependencies(Collections.EMPTY_LIST) + .withControlDependencies(java.util.Collections.EMPTY_LIST) .noOp(); } if (weightsRankStatic != valuesRankStatic) { @@ -88,8 +104,8 @@ public static Op assertBroadcastable( ASSERT_BROADCAST_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, - valuesShapeStatic.toString(), - weightsShapeStatic.toString())); + valuesShapeStatic, + weightsShapeStatic)); } for (int i = 0; i < valuesRankStatic; i++) { @@ -100,8 +116,8 @@ public static Op assertBroadcastable( "%s Mismatch at dim %d. 
values.shape=%s weights.shape=%s.", ASSERT_BROADCAST_ERROR_PREFIX, i, - valuesShapeStatic.toString(), - weightsShapeStatic.toString())); + valuesShapeStatic, + weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -185,13 +201,13 @@ private static Operand canBroadcastDims( } /** - * Broadcast weights to the same shape as values. + * Broadcast {@code weights} to the same shape as {@code values}. * * @param tf the TensorFlow ops - * @param weights Operand whose shape is broadcastable to values. + * @param weights Operand whose shape is broadcastable to {@code values}. * @param values Operand of any shape * @param the type of Operands - * @return weights broadcast to values shape + * @return {@code weights} broadcast to {@code values} shape */ public static Operand broadcastWeights( Ops tf, Operand weights, Operand values) { @@ -212,11 +228,473 @@ public static Operand broadcastWeights( return ctf.math.mul(weights, tf.onesLike(values)); } - // aliases for mean + /** + * Checks that all the Symbolic Shapes are consistent. + * + * @param tf the TensorFlow Ops + * @param symbols the list of Symbolic Shapes + * @param message the error message if the shapes are not consistent. + * @return a list of Operands to check the consistency of the symbolic shapes ready to add to a + * control dependency. + */ + public static List assertShapes( + Ops tf, List> symbols, String message) { + List updateOperations = new ArrayList<>(); + // check that the symbolic shape rank matches the operands rank. 
+ symbols.forEach( + symbol -> { + Operand operand = symbol.getOperand(); + int rank = symbol.rank(); + Rank tfRank = tf.rank(operand); + Op assertion = + tf.withSubScope("assertShapes-1") + .assertThat( + tf.math.equal(tfRank, tf.constant(rank)), + Collections.singletonList(tf.constant(message))); + updateOperations.add(assertion); + }); + + Map> dict = new HashMap<>(); + + // check that each operand's dimension size equals the corresponding symbolic shape's dimensions + // size + symbols.forEach( + symbol -> { + AtomicLong ll = new AtomicLong(); + symbol + .getSymbols() + .forEach( + s -> { + Operand size = dict.get(s); + if (size == null) { + // save size for later checks + size = + tf.shape.size(symbol.getOperand(), tf.constant(ll.get()), TInt64.class); + dict.put(s, size); + } + Op assertion = + tf.withSubScope("assertShapes-2") + .assertThat( + tf.math.equal( + tf.shape.size( + symbol.getOperand(), + tf.constant(ll.getAndIncrement()), + TInt64.class), + size), + Collections.singletonList(tf.constant(message))); + updateOperations.add(assertion); + }); + }); + + return updateOperations; + } /** - * Calculate the mean of the operand, along all axes and keepDims is false - * + * Returns an op to update the given confusion matrix variables. + * + *

For every pair of values in {@code labels} and {@code predictions}: + * + *

+   * TRUE_POSITIVES:  {@code labels} == true and {@code predictions} > thresholds
+   * FALSE_POSITIVES: {@code labels} == true and {@code predictions} <= thresholds
+   * TRUE_NEGATIVES:  {@code labels} == false and {@code predictions} <= thresholds
+   * FALSE_NEGATIVE:  {@code labels} == false and {@code predictions} > thresholds
+   * 
+ * + *

The results will be weighted and added together. When multiple thresholds are provided, we + * will repeat the same for every threshold. + * + *

For estimation of these metrics over a stream of data, the function creates an `update_op` + * operation that updates the given variables. + * + *

{@code labels}, {@code predictions}, and {@code sampleWeight} tensors are + * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code + * sampleWeight} is then broadcast to the shape of {@code predictions}. + * + * @param tf the TensorFlow Ops + * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding variables to update as values. If {@code multiLabel}, then the variable + * shapes are (T, D), where T is the number of thresholds and D is the number of classes + * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then + * the variable shapes are (T). + * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and + * corresponding initializer Operands to for {@code variablesToUpdate}. + * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the + * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra + * dimension of size 1 that would be squeezed. + * @param predictions the predictions shape (N, Cx, P1?) + * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when + * topK is set + * @param topK optional, indicates that only the top k predictions should be considered. Applied + * before possibly slicing by {@code classIndex}. + * @param classIndex optional, limits the prediction and labels to the specified class. This is an + * integer index into the first dimension of Cx. + * @param sampleWeight optional {@code Tensor} that is aligned with labels and predictions as + * explained above. Use weights of 0 to mask values. + * @param multiLabel indicates whether multidimensional prediction/labels should be treated as + * multilabel responses, or flattened into a single label. 
When true, the values of {@code + * variablesToUpdate} must have a second dimension equal to the number of labels and + * predictions per example, and those tensors must not be RaggedTensors. + * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied + * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES + * without explicit multilabel handling (i.e. when the data is to be flattened). Must have + * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex + * } is provided, then the final dimension of Dx is 1. These weights will be broadcast + * across the 0th dimension (the examples dimension) of {@code predictions}. May be null. + * Must be null if {@code multiLabel}. + * @param the data type for the variables + * @throws IllegalArgumentException If {@code predictions} and {@code labels} have + * mismatched shapes, or if {@code sampleWeight} is not null and its shape + * doesn't match {@code predictions}, or if {@code multiLabel && labelWeights != null}.. + * @return an op to update the given confusion matrix variables. + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static List updateConfusionMatrixVariables( + Ops tf, + Map> variablesToUpdate, + Map> varInitializers, + Operand labels, + Operand predictions, + Operand thresholds, + Integer topK, + Integer classIndex, + Operand sampleWeight, + boolean multiLabel, + Operand labelWeights) { + if (multiLabel && labelWeights != null) + throw new IllegalArgumentException( + "labelWeights for multilabel data should be handled outside of updateConfusionMatrixVariables when multiLabel is true."); + + if (variablesToUpdate == null || variablesToUpdate.isEmpty()) { + return Collections.EMPTY_LIST; + } + + Operand tLabels = labels; + Operand tPredictions = predictions; + Operand tSampleWeight = sampleWeight; + + // We will tile data for threshold comparisons. 
We want a cross product of thresholds and + // predictions/labels: + // In the multilabel case, we want a data shape of (T, N, D). + // else (T, ND). + // where + // T is numThresholds (the size of the 0th dimension of thresholds) + // N is the number of examples (the 0th dimension of labels and predictions) + // Dx == Cx except that if classIndex != null, + // then the last dimension of Dx is size 1 + // D is the product of all Dx + // ND is N * D + + // size of the 0th dimension of thresholds + // reshape to scalar for operations later. + Operand numThresholds = + tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); + + // if multilabel, then (rank(thresholds) == 1) + // else true + Operand oneThresh; + if (multiLabel) { + oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); + } else { + // TODO handle Ragged Tensors???? + // [y_pred, + // y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], + // sampleWeights) + oneThresh = tf.constant(true); + } + + List controlOps = new ArrayList<>(); + Operand axes = allAxes(tf, tPredictions); + controlOps.add( + tf.withSubScope("updateConfusionMatrixVariables-1") + .assertThat( + tf.reduceAll( + tf.math.greaterEqual( + tPredictions, cast(tf, tf.constant(0), tPredictions.type())), + axes), + Collections.singletonList(tf.constant("predictions must be >= 0")))); + controlOps.add( + tf.withSubScope("updateConfusionMatrixVariables-2") + .assertThat( + tf.reduceAll( + tf.math.lessEqual(tPredictions, cast(tf, tf.constant(1), tPredictions.type())), + axes), + Collections.singletonList(tf.constant("predictions must be <= 1")))); + + LossTuple result = + LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight); + tPredictions = result.getTarget(); // shape (N, Cx) + tLabels = result.getLabels(); // shape (N, Cx) + tSampleWeight = result.getSampleWeights(); // broadcastable to (N, Dx) + + if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) 
+ throw new IllegalArgumentException( + String.format( + "Shapes %s and %s are incompatible", + tPredictions.shape().toString(), tLabels.shape().toString())); + + if (topK != null) { + tPredictions = filterTopK(tf, tPredictions, topK); + } + + if (classIndex != null) { + // Slice to new shapes (N, Dx) + tLabels = tf.squeeze(tf.gather(tLabels, + tf.constant(new int[] {classIndex}), tf.constant(-1)), + Squeeze.axis(Collections.singletonList(1L))); + tPredictions = tf.squeeze(tf.gather(tPredictions, + tf.constant(new int[] {classIndex}), tf.constant(-1)), + Squeeze.axis(Collections.singletonList(1L))); + } + org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); + + Operand numExamples = + tf.reshape(tf.shape.size(tPredictions, tf.constant(0)), tf.constant(Shape.scalar())); + + // number of labels (and predictions) per example (after possibly slicing by classIndex) + // In the notation we are using for comments, this is D. + Operand numLabels = + tf.select( + tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), + tf.constant(1), + tf.reduceProd( + // take all but the first dimension + tf.shape.takeLast( + predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), + tf.constant(0))); + + // threshLabelTile == numLabels except in one case: + // if multilabel and rank(thresholds) != 1, then threshLabelTile is 1 + Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); + + // if multilabel, then shape (1, N, Dx) + // else shape (1, ND), + Operand predictionsExtraDim; + Operand labelsExtraDim; + + if (multiLabel) { + predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0)); + labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0)); + } else { + predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1))); + labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1))); + } + + // the shape of each thresholds tile + // if multilabel, then 
[T, 1, -1] + // else [T, -1] + List> threshPretileShape; + + // the tiling multiples for thresholds + // We want to repeat the thresholds for each data position. + // if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels) + // else [1, ND] + List> threshTiles; + + // tiling multiples for predictionsExtraDim and labelsExtraDim + // We want to repeat the predictions and labels for each threshold. + // If multilabel, then [T, 1, 1] + // else [T, 1] + List> dataTiles; + + if (multiLabel) { + threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); + threshTiles = Arrays.asList(tf.constant(1), numExamples, threshLabelTile); + dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1)); + } else { + threshPretileShape = + Arrays.asList(tf.reshape(numThresholds, tf.constant(Shape.scalar())), tf.constant(-1)); + Operand mul = tf.math.mul(numExamples, numLabels); + threshTiles = Arrays.asList(tf.constant(1), mul); + dataTiles = Arrays.asList(numThresholds, tf.constant(1)); + } + + // if multilabel, then shape (T, 1, T*) + // else shape (T, T*) + // where T* is the product of all threshold dimension sizes beyond 0 + Operand thresholdsReshaped = + tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); + + Operand threshTilesShape = tf.stack(threshTiles); + + // if multilabel, then + // if thresholds has rank > 1, then shape (T, N, T*) + // else shape (T, N, D) + // else shape (T, ND) + Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); + + Operand dataTilesShape = tf.stack(dataTiles); + + // if multilabel, then shape (T, N, D) + // else (T, ND) + Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); + + // Compare predictions and threshold. 
+ Operand predIsPos = tf.math.greater(predsTiled, threshTiled); + // Tile labels by number of thresholds + Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles)); + Operand weightsTiled; + if (tSampleWeight != null) { + tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); + // if multilabel, then + // reshape tSampleWeight to (1, N, threshLabelTile) + // tile the result into shape (T, N, threshLabelTile) + // where threshLabelTile is typically D + // else + // reshape tSampleWeight to (1, ND) + // tile the result into shape (T, ND) + weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); + } else { + weightsTiled = null; + } + + if (labelWeights != null) { + // Change shape to (1, Dx). + Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0)); + + // Broadcast to shape (N, Dx). + lLabelWeights = tf.broadcastTo(lLabelWeights, tPredictions); + + // If multilabel: shape (T, N, D) + // else: shape (T, ND) + Operand labelWeightsTiled = + tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles)); + + if (weightsTiled == null) { + weightsTiled = labelWeightsTiled; + } else { + weightsTiled = tf.math.mul(weightsTiled, labelWeightsTiled); + } + } + + Map loopVars = new HashMap<>(); + loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, new Operand[] {labelIsPos, predIsPos}); + Variable updateTN = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); + Variable updateFP = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); + Variable updateFN = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); + + Operand predIsNeg = null; + Operand labelIsNeg; + if (updateFN != null || updateTN != null) { + predIsNeg = tf.math.logicalNot(predIsPos); + loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, new Operand[] {labelIsPos, predIsNeg}); + } + + if (updateFP != null || updateTN != null) { + labelIsNeg = tf.math.logicalNot(labelIsPos); + 
loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, new Operand[] {labelIsNeg, predIsPos}); + if (updateTN != null) { + loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, new Operand[] {labelIsNeg, predIsNeg}); + } + } + + final Operand weightsTiledF = weightsTiled; + loopVars + .keySet() + .forEach( + (c) -> { + if (variablesToUpdate.containsKey(c)) { + Operand[] op = loopVars.get(c); + // op[0] = label, op[1] == prediction + controlOps.add( + weightedAssignAdd( + tf, + op[0], + op[1], + weightsTiledF, + variablesToUpdate.get(c), + varInitializers.get(c))); + } + }); + + return controlOps; + } + + /** + * Creates an Operand that adds the values by taking the logical and of labels and predictions to + * the specified confusion matrix variable. + * + * @param tf The TensorFlow Ops + * @param labels the labels + * @param predictions the predictions + * @param weights the weights applied to the logical and result, may be null + * @param variable the variable to update + * @param initializer the variable initializer to be applied to the variable, may be null. + * @param the data type for the variable. + * @return an Operand that updates the variable. 
+ */ + private static Operand weightedAssignAdd( + Ops tf, + Operand labels, + Operand predictions, + Operand weights, + Variable variable, + Assign initializer) { + Class type = variable.type(); + Operand labelAndPred = cast(tf, tf.math.logicalAnd(labels, predictions), type); + + if (weights != null) { + labelAndPred = tf.math.mul(labelAndPred, weights); + } + // if multilabel: + // sum across examples, leaving shape (T, D) + // else: + // sum across ND, leaving shape (T) + Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); + Operand assignAdd; + if (initializer != null) { + Ops tfc = + tf.withSubScope("weightedAssignAdd") + .withControlDependencies(Collections.singletonList(initializer)); + assignAdd = tfc.assignAdd(variable, valueSum); + } else { + assignAdd = tf.assignAdd(variable, valueSum); + } + return assignAdd; + } + + /** + * Filters top-k values in the last dim of x and set the rest to NEG_INF. + * + *

Used for computing top-k prediction values in dense labels (which has the same shape as + * predictions) for recall and precision top-k metrics. + * + * @param tf The TensorFlow Ops + * @param x the tensor with any dimensions to filter + * @param topK the number of values to keep. + * @param the data type for x and the return value. + * @return the topK prediction values. + */ + private static Operand filterTopK(Ops tf, Operand x, int topK) { + Class type = x.type(); + Shape xShape = x.shape(); + // top has the same rank as x; the last dimension becomes indices of the topK features. + TopK top = tf.nn.topK(x, tf.constant(topK), TopK.sorted(false)); + // oneHot has an additional dimension: the one-hot representation of each topK index. + OneHot oneHot = + tf.oneHot( + top.indices(), + cast(tf, tf.constant(xShape.size(xShape.numDimensions() - 1)), TInt32.class), + tf.constant(1), + tf.constant(0), + OneHot.axis(-1L)); + // Sum the one-hot representations along the last dimension of x. + Operand topKMask = cast(tf, tf.reduceSum(oneHot, tf.constant(-2)), type); + + // x * top_k_mask + NEG_INF * (1 - top_k_mask) + Operand add1 = tf.math.mul(x, topKMask); + Operand add2 = + tf.math.mul( + cast(tf, tf.constant(NEG_INF), type), + tf.math.sub(cast(tf, tf.constant(1), type), topKMask)); + return tf.math.add(add1, add2); + } + + // alias for mean + + /** + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false + * } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -228,8 +706,8 @@ public static Operand mean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is + * {@code false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -247,12 +725,12 @@ public static Operand mean( * * @param tf the TensorFlow Ops * 
@param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. * @param the type of the operand - * @return the mean of elements of x. + * @return the mean of elements of {@code x}. */ public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); @@ -264,12 +742,12 @@ public static Operand mean(Ops tf, Operand x, boolean * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. * @param the data type of the Operand - * @return the mean of elements of x. + * @return the mean of elements of {@code x}. 
*/ public static Operand mean( Ops tf, Operand x, Operand axes, boolean keepDims) { @@ -279,9 +757,134 @@ public static Operand mean( return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } + public static + LossTuple raggedAssertCompatibleAndGetFlatValues( + Ops tf, Operand labels, Operand predictions) { + // TODO handle ragged Tensors + Operand tLabels = cast(tf, labels, predictions.type()); + return new LossTuple<>(tLabels, predictions); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape {@code [n, n]}, where {@code n} is the + * number of valid labels for a given classification task. Both prediction and labels must be 1-D + * arrays of the same shape in order for this function to work. + * + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum + * value in either predictions or labels. Class labels are expected to start at 0. For example, if + * {@code numClasses}` is 3, then the possible labels would be {@code [0, 1, 2]}. + * + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to + * the total value of the confusion matrix cell. + * + *

For example: + * + *

{@code
+   *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
+   *          [[0 0 0 0 0]
+   *           [0 0 1 0 0]
+   *           [0 0 1 0 0]
+   *           [0 0 0 0 0]
+   *           [0 0 0 0 1]]
+   * }
+ * + * Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param tf the TensorFlow Ops + * @param labels 1-D {@code Operand} of real labels for the classification task. + * @param predictions 1-D {@code Operand} of predictions for a given classification. + * @param numClasses The possible number of labels the classification task can have. If this value + * is not provided, it will be calculated using both predictions and labels array. + * @param weights optional weights to be applied to the confusion matrix + * @param type Data type of the confusion matrix. + * @param the type of Operands + * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} + * representing the confusion matrix, where {@code n} is the number of possible labels in + * the classification task. + * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do + * not have compatible shapes, or if {@code weights} is not {@code null} and its + * shape is not compatible with {@code predictions}. + */ + // TODO should this be moved to FrameworkOps under math. 
+ public static Operand confusionMatrix( + Ops tf, + Operand labels, + Operand predictions, + Operand numClasses, + Operand weights, + Class type) { + if (!predictions.shape().isCompatibleWith(labels.shape())) + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with labels shape %s", + predictions.shape().toString(), labels.shape().toString())); + tf = tf.withSubScope("confusionMatrix"); + LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, predictions, labels, null); + Operand tPredictions = cast(tf, ops.getTarget(), TInt64.class); + Operand tLabels = cast(tf, ops.getLabels(), TInt64.class); + + List labelControls = new ArrayList<>(); + List predictionControls = new ArrayList<>(); + + labelControls.add( + tf.assertThat( + tf.reduceAll(tf.math.greaterEqual(tLabels, tf.constant(0L)), allAxes(tf, tLabels)), + Collections.singletonList(tf.constant("`labels` contains negative values")))); + + predictionControls.add( + tf.assertThat( + tf.reduceAll( + tf.math.greaterEqual(tPredictions, tf.constant(0L)), allAxes(tf, tPredictions)), + Collections.singletonList(tf.constant("`predictions` contains negative values")))); + if (numClasses == null) { + // Per the javadoc/example, the inferred class count is one plus the maximum label value. + numClasses = + tf.math.add( + tf.math.maximum( + tf.reduceMax(tPredictions, allAxes(tf, tPredictions)), + tf.reduceMax(tLabels, allAxes(tf, tLabels))), + tf.constant(1L)); + } else { + labelControls.add( + tf.assertThat( + tf.reduceAll(tf.math.less(tLabels, numClasses), allAxes(tf, tLabels)), + Collections.singletonList(tf.constant("`labels` out of bounds")))); + predictionControls.add( + tf.assertThat( + tf.reduceAll(tf.math.less(tPredictions, numClasses), allAxes(tf, tPredictions)), + Collections.singletonList(tf.constant("`predictions` out of bounds")))); + } + + if (weights != null) { + if (!tPredictions.shape().isCompatibleWith(weights.shape())) { + throw new IllegalArgumentException( + String.format( + "Prediction shape %s is not compatible with weights shape %s", + tPredictions.shape().toString(), 
weights.shape().toString())); + } + } + + Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls); + tLabels = tfc.identity(tLabels); + + tfc = tf.withSubScope("confusionMatrixPredictions").withControlDependencies(predictionControls); + tPredictions = tfc.identity(tPredictions); + + Operand shape = tf.stack(Arrays.asList(numClasses, numClasses)); + Operand indices = tf.stack(Arrays.asList(tLabels, tPredictions), Stack.axis(1L)); + Operand values = + weights == null ? cast(tf, tf.onesLike(tPredictions), type) : cast(tf, weights, type); + SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); + Operand zeroMatrix = tf.zeros(shape, type); + + return tf.sparse.sparseTensorDenseAdd( + cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); + } + /** - * Calculate the mean of the operand, along all axes and keepDims is false - * + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false + * } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -292,8 +895,8 @@ public static Operand booleanMean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is + * {@code false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -310,11 +913,11 @@ public static Operand booleanMean( * * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. 
If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. - * @return the mean of elements of x containing floating point numbers + * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); @@ -326,11 +929,11 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. - * @return the mean of elements of x containing floating point numbers + * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean( Ops tf, Operand x, Operand axes, boolean keepDims) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java index 2a26967b9f2..b96d2dfa1d0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -106,11 +106,11 @@ public Op resetStates() { } /** - * Updates the metric variables based on the inputs. 
At least one input arg required for - * values, an optional additional input for the sampleWeights + * Updates the metric variables based on the inputs. At least one input arg required for {@code + * values}, an optional additional input for the {@code sampleWeights} * * @param values the inputs to be passed to update state, this may not be null - * @param sampleWeights sample weights to be applied to values, may be null. + * @param sampleWeights sample weights to be applied to values, will default to 1 if null. * @return the result with a control dependency on update state Operands * @throws IllegalArgumentException if values is null */ @@ -129,13 +129,16 @@ public List updateStateList( if (sampleWeights != null) { tSampleWeights = cast(getTF(), sampleWeights, getResultType()); + // Update dimensions of weights to match with values if possible. LossTuple tuple = LossesHelper.squeezeOrExpandDimensions(getTF(), null, tValues, tSampleWeights); tValues = tuple.getTarget(); tSampleWeights = tuple.getSampleWeights(); try { + // Broadcast weights if possible tSampleWeights = MetricsHelper.broadcastWeights(getTF(), tSampleWeights, tValues); } catch (IllegalArgumentException ex) { + // reduce values to same ndim as weight array // if we get here we have static shapes with either // different ranks or different dimension sizes. // first, reduce the values down to the rank of the samples @@ -162,7 +165,9 @@ public List updateStateList( getTF().assignAdd(total, cast(getTF(), weightedValueSum, total.type())); updateOperations.add(totalUpdate); Operand numValues; + // Exit early if the reduction doesn't have a denominator. if (reduction != MetricReduction.SUM) { + // Update `count` for reductions that require a denominator. 
switch (reduction) { case SUM_OVER_BATCH_SIZE: numValues = cast(getTF(), getTF().constant(tValues.shape().size()), resultType); @@ -183,6 +188,7 @@ public List updateStateList( throw new UnsupportedOperationException( String.format("reduction [%s] not implemented", reduction)); } + Operand totalCount = getTF().assignAdd(this.count, numValues); updateOperations.add(totalCount); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java new file mode 100644 index 00000000000..60a6c1ea3df --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SensitivitySpecificityBase.java @@ -0,0 +1,286 @@ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Abstract base class for computing sensitivity and specificity. 
+ * + * @param The data type for the metric result + */ +public abstract class SensitivitySpecificityBase extends Metric { + + public static final int DEFAULT_NUM_THRESHOLDS = 200; + + public static final String TRUE_POSITIVES = "TRUE_POSITIVES"; + public static final String FALSE_POSITIVES = "FALSE_POSITIVES"; + public static final String TRUE_NEGATIVES = "TRUE_NEGATIVES"; + public static final String FALSE_NEGATIVES = "FALSE_NEGATIVES"; + protected final int numThresholds; + + protected final float[] thresholds; + private final String truePositivesName; + private final String falsePositivesName; + private final String trueNegativesName; + private final String falseNegativesName; + private final Class type; + protected Variable truePositives; + protected Variable falsePositives; + protected Variable trueNegatives; + protected Variable falseNegatives; + + private Assign truePositivesInitializer; + private Assign falsePositivesInitializer; + private Assign trueNegativesInitializer; + private Assign falseNegativesInitializer; + + /** + * Creates a SensitivitySpecificityBase Metric + * + * @param tf the TensorFlow Ops + * @param name the name of the metric instance, if null then {@link Class#getSimpleName()} is used + * @param numThresholds The number of thresholds to use for matching the given recall. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the variables + * @throws IllegalArgumentException if numThresholds <= 0. 
+ */ + protected SensitivitySpecificityBase( + Ops tf, String name, int numThresholds, long seed, Class type) { + super(tf, name, seed); + if (numThresholds <= 0) throw new IllegalArgumentException("numThresholds must be > 0."); + this.type = type; + this.truePositivesName = this.getVariableName(TRUE_POSITIVES); + this.falsePositivesName = this.getVariableName(FALSE_POSITIVES); + this.trueNegativesName = this.getVariableName(TRUE_NEGATIVES); + this.falseNegativesName = this.getVariableName(FALSE_NEGATIVES); + + this.numThresholds = numThresholds; + + if (this.numThresholds == 1) { + this.thresholds = new float[] {0.5f}; + } else { + this.thresholds = new float[numThresholds]; + for (int i = 0; i < numThresholds - 2; i++) { + this.thresholds[i + 1] = (i + 1f) / (float) (numThresholds - 1); + } + this.thresholds[numThresholds - 1] = 1f; + } + init(); + } + + /** Initializes the Variables */ + private void init() { + Ops tf = getTF(); + Zeros zeros = new Zeros<>(tf); + Shape varShape = Shape.of(numThresholds); + Operand zero = zeros.call(tf.constant(varShape), type); + + if (this.getTruePositives() == null) { + + truePositives = tf.withName(truePositivesName).variable(zero); + truePositivesInitializer = tf.assign(truePositives, zero); + } + if (this.getFalsePositives() == null) { + + falsePositives = tf.withName(falsePositivesName).variable(zero); + falsePositivesInitializer = tf.assign(falsePositives, zero); + } + if (this.getTrueNegatives() == null) { + + trueNegatives = tf.withName(trueNegativesName).variable(zero); + trueNegativesInitializer = tf.assign(trueNegatives, zero); + } + if (this.getFalseNegatives() == null) { + + falseNegatives = tf.withName(falseNegativesName).variable(zero); + falseNegativesInitializer = tf.assign(falseNegatives, zero); + } + } + + /** + * Gets a control dependency Op to initialize all the variables + * + * @return a control dependency Op to initialize all the variables + */ + public Op initializeVariables() { + List varInitializers = 
new ArrayList<>(); + + if (truePositivesInitializer != null) { + varInitializers.add(truePositivesInitializer); + } + if (falsePositivesInitializer != null) { + varInitializers.add(falsePositivesInitializer); + } + if (trueNegativesInitializer != null) { + varInitializers.add(trueNegativesInitializer); + } + if (falseNegativesInitializer != null) { + varInitializers.add(falseNegativesInitializer); + } + + return getTF().withControlDependencies(varInitializers).noOp(); + } + + /** + * Accumulates confusion matrix statistics. + * + * @param labels The ground truth values. + * @param predictions the predictions + * @param sampleWeights Optional weighting of each example. Defaults to 1. Rank is either 0, or + * the same rank as labels, and must be broadcastable to labels. + * @return a List of Operations to update the metric state. + */ + @Override + @SuppressWarnings("unchecked") + public List updateStateList( + Operand labels, + Operand predictions, + Operand sampleWeights) { + Ops tf = getTF(); + Operand tLabels = cast(tf, labels, type); + Operand tPredictions = cast(tf, predictions, type); + Operand tSampleWeights = sampleWeights != null ? 
cast(tf, sampleWeights, type) : null; + + Map> confusionMatrix = new HashMap<>(); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_POSITIVES, getTruePositives()); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_POSITIVES, getFalsePositives()); + confusionMatrix.put(ConfusionMatrixEnum.TRUE_NEGATIVES, getTrueNegatives()); + confusionMatrix.put(ConfusionMatrixEnum.FALSE_NEGATIVES, getFalseNegatives()); + + return MetricsHelper.updateConfusionMatrixVariables( + tf, + confusionMatrix, + Collections.EMPTY_MAP, + tLabels, + tPredictions, + tf.constant(thresholds), + null, + null, + tSampleWeights, + false, + null); + } + + /** {@inheritDoc} */ + @Override + public Op resetStates() { + return initializeVariables(); + } + + /** + * Gets the truePositives variable + * + * @return the truePositives + */ + public Variable getTruePositives() { + return truePositives; + } + + /** + * Gets the falsePositives variable + * + * @return the falsePositives + */ + public Variable getFalsePositives() { + return falsePositives; + } + + /** + * Gets the trueNegatives variable + * + * @return the trueNegatives + */ + public Variable getTrueNegatives() { + return trueNegatives; + } + + /** + * Gets the falseNegatives variable + * + * @return the falseNegatives + */ + public Variable getFalseNegatives() { + return falseNegatives; + } + + /** + * Gets the numThresholds + * + * @return the numThresholds + */ + public int getNumThresholds() { + return numThresholds; + } + + + + /** + * Gets the thresholds + * + * @return the thresholds + */ + public float[] getThresholds() { + return thresholds; + } + + /** + * Gets the truePositives variable name + * + * @return the truePositivesName + */ + public String getTruePositivesName() { + return truePositivesName; + } + + /** + * Gets the falsePositives variable name + * + * @return the falsePositivesName + */ + public String getFalsePositivesName() { + return falsePositivesName; + } + + /** + * Gets the 
trueNegatives variable name + * + * @return the trueNegativesName + */ + public String getTrueNegativesName() { + return trueNegativesName; + } + + /** + * Gets the falseNegatives variable name + * + * @return the falseNegativesName + */ + public String getFalseNegativesName() { + return falseNegativesName; + } + + /** + * Gets the type + * + * @return the type + */ + public Class getType() { + return type; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java index 467dea19b57..68157632557 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -26,16 +26,16 @@ public class SetsOps { /** - * Computes set difference of elements in last dimension of a and b with - * aMinusB set to true. + * Computes set difference of elements in last dimension of {@code a} and {@code b} with + * {@code aMinusB} set to true. * - *

All but the last dimension of a and b must match + *

All but the last dimension of {@code a} and {@code b} must match * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -44,16 +44,16 @@ public static Operand difference(Ops tf, Operand a, Op } /** - * Computes set difference of elements in last dimension of a and b. + * Computes set difference of elements in last dimension of {@code a} and {@code b}. * - *

All but the last dimension of a and b must match + *

All but the last dimension of {@code a} and {@code b} must match * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param aMinusB whether to subtract b from a, vs vice versa. * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -63,13 +63,13 @@ public static Operand difference( } /** - * Computes set union of elements in last dimension of a and b. + * Computes set union of elements in last dimension of {@code a} and {@code b}. * * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -78,13 +78,13 @@ public static Operand union(Ops tf, Operand a, Operand } /** - * Computes set intersection of elements in last dimension of a and b. + * Computes set intersection of elements in last dimension of {@code a} and {@code b}. 
* * @param tf the TensorFlow Ops - * @param a The first operand representing set a - * @param b The other operand representing set b + * @param a The first operand representing set {@code a} + * @param b The other operand representing set {@code b} * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ @@ -93,14 +93,14 @@ public static Operand intersection(Ops tf, Operand a, } /** - * Compute set operation of elements in last dimension of a and b. + * Compute set operation of elements in last dimension of {@code a} and {@code b}. * * @param tf the TensorFlow Ops * @param a The first set operation operand * @param b The other et operation operand * @param setOperation The set operation to perform, {@link Operation}. * @param the data type for the sets - * @return An Operand with the same rank as a and b, and all but the + * @return An Operand with the same rank as {@code a} and {@code b}, and all but the * last dimension the same. Elements along the last dimension contain the results of the set * operation. */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java new file mode 100644 index 00000000000..d28185ae041 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SymbolicShape.java @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class SymbolicShape { + private Operand operand; + private List symbols = new ArrayList<>(); + + public SymbolicShape(Operand operand, String... symbols) { + this.operand = operand; + this.symbols.addAll(Arrays.asList(symbols)); + } + + /** @return the operand */ + public Operand getOperand() { + return operand; + } + + /** @param operand the operand to set */ + public void setOperand(Operand operand) { + this.operand = operand; + } + + /** @return the symbols */ + public List getSymbols() { + return symbols; + } + + /** @param symbols the symbols to set */ + public void setSymbols(List symbols) { + this.symbols = symbols; + } + + public int rank() { + return this.symbols.size(); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java new file mode 100644 index 00000000000..6583465da2e --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -0,0 +1,198 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.NoOp; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TNumber; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Weight broadcasting operations. + * + *

In {@link org.tensorflow.framework.losses} and `{@link org.tensorflow.framework.metrics}, we support limited weight broadcasting. This file includes + * operations for those broadcasting rules. + */ +public class WeightsBroadcastOps { + + private static final String ASSERT_BROADCASTABLE_ERROR_PREFIX = + "weights can not be broadcast to values."; + + /** + * Asserts that {@code weights} can be broadcast to {@code values} + * + * @param tf the TensorFlow Ops + * @param weights the weights Operand + * @param values Operand of values to which weights are applied. + * @return {@code Operation} raising a tensorflow InvalidArgumentError if {@code weights} has incorrect shape. {@link NoOp} if + * static checks determine {@code weights} has correct shape. + * @param the type of weights and values + * @throws IllegalArgumentException If static checks determine {@code weights} has incorrect shape. + */ + public static Op assertBroadcastable( + Ops tf, Operand weights, Operand values) { + Operand weightsShape = tf.shape(weights); + Operand weightsRank = tf.rank(weights); + Shape weightsShapeStatic = weights.shape(); + int weightsRankStatic = weightsShapeStatic.numDimensions(); + + Operand valuesShape = tf.shape(values); + Operand valuesRank = tf.rank(values); + Shape valuesShapeStatic = values.asOutput().shape(); + int valuesRankStatic = valuesShapeStatic.numDimensions(); + + if (weightsRankStatic != -1 && valuesRankStatic != -1) { + if (weightsRankStatic == 0) { + return tf.withSubScope("staticScalarCheckSuccess") + .withControlDependencies(Collections.emptyList()) + .noOp(); + } + if (weightsRankStatic != valuesRankStatic) { + throw new IllegalArgumentException( + String.format( + "%s values.rank=%d. weights.rank=%d. values.shape=%s. 
weights.shape=%s.", + ASSERT_BROADCASTABLE_ERROR_PREFIX, + valuesRankStatic, + weightsRankStatic, + valuesShapeStatic, + weightsShapeStatic)); + } + + for (int i = 0; i < valuesRankStatic; i++) { + if (weightsShapeStatic.size(i) != 1 && valuesShapeStatic.size(i) != weightsShapeStatic.size(i)) { + throw new IllegalArgumentException( + String.format( + "%s Mismatch at dim %s. values.shape=%s weights.shape=%s.", + ASSERT_BROADCASTABLE_ERROR_PREFIX, + i, + valuesShapeStatic, + weightsShapeStatic)); + } + } + return tf.withSubScope("staticDimsCheckSuccess") + .withControlDependencies(Collections.emptyList()) + .noOp(); + } + // Dynamic checks. + Operand isScalar = tf.math.equal(weightsRank, tf.constant(0)); + List> data = + Arrays.asList( + tf.constant(ASSERT_BROADCASTABLE_ERROR_PREFIX), + tf.constant("weights.shape="), + weightsShape, + tf.constant("values.shape="), + valuesShape, + tf.constant("isScalar="), + isScalar); + + Operand isValidShape = + tf.select( + isScalar, + isScalar, + hasValidNonscalarShape(tf, weightsRank, weightsShape, valuesRank, valuesShape)); + + return tf.assertThat(isValidShape, data); + } + + /** + * Check to see that weights and values have the same rank, if they do, then check each + * corresponding dim of each. 
+ * + * @param tf The TensorFlow Ops + * @param weightsRank the rank operand for the weights + * @param weightsShape the shape operand for the weights + * @param valuesRank the rank operand for the values + * @param valuesShape the shape operand for the values + * @return a boolean Operand, true if both shapes have the same rank, and each dimension is the + * same + */ + private static Operand hasValidNonscalarShape( + Ops tf, + Operand weightsRank, + Operand weightsShape, + Operand valuesRank, + Operand valuesShape) { + tf = tf.withSubScope("hasValidNonscalarShape"); + Operand isSameRank = tf.math.equal(valuesRank, weightsRank); + return tf.select(isSameRank, hasValidDims(tf, weightsShape, valuesShape), isSameRank); + } + + /** + * Checks that each dimension of the two shapes are the same size, or that the weight dimension size is 1. + * + * @param tf the TensorFlow Ops + * @param weightsShape the shape of the weights + * @param valuesShape the shape of the values + * @return a boolean Operand, true if all the dimensions of the two shapes are the same. + */ + private static Operand hasValidDims( + Ops tf, Operand weightsShape, Operand valuesShape) { + tf = tf.withSubScope("hasInvalidDims"); + + Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); + Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); + Operand weightsShape2d = tf.expandDims(weightsShape, tf.constant(-1)); + + Operand invalidDims = SetsOps.difference(tf, weightsShape2d, validDims); + Operand numInvalidDims = tf.size(invalidDims, TInt32.class); + return tf.math.equal(tf.constant(0), numInvalidDims); + } + + /** + * Broadcast {@code weights} to the same shape as {@code values}. + * + *

This returns a version of {@code weights} following the same broadcast rules as {@code + * mul(weights, + * values)}, but limited to the weights shapes allowed by {@code assertBroadcastable} + * When computing a weighted average, use this function to broadcast {@code weights} before + * summing them; e.g., {@code reduceSum(w * v) / reduceSum(_broadcast_weights(w, v))}. + * + * @param tf the TensorFlow ops + * @param weights Operand whose shape is able to be broadcast to {@code values} + * @param values Tensor` of any shape + * @param the type of Operand + * @return {@code weights} broadcast to {@code values} shape + */ + public static Operand broadcastWeights( + Ops tf, Operand weights, Operand values) { + tf = tf.withSubScope("broadcast_weights"); + Operand tValues = cast(tf, values, weights.type()); + + Shape weightsShape = weights.shape(); + Shape valuesShape = tValues.shape(); + + if (!weightsShape.hasUnknownDimension() + && !valuesShape.hasUnknownDimension() + && weightsShape.isCompatibleWith(valuesShape)) { + return weights; + } + + Op dependencies = assertBroadcastable(tf, weights, tValues); + Ops tf1 = + tf.withSubScope("assertBroadcastable") + .withControlDependencies(Collections.singletonList(dependencies)); + return tf1.math.mul(weights, tf.onesLike(tValues)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java new file mode 100644 index 00000000000..81d658ff3a5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/SparseTensor.java @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.utils; + +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.op.SparseOps; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TType; + +/** + * This is a helper class that represents a sparse tensor who's attributes may be passed to {@link + * SparseOps} methods. + * + *

This class does not inherit from {@link Tensor}, but is merely a place to accumulate the + * properties that are needed for the {@link SparseOps} methods. + * + * @param the type of the SparseTensor's values. + */ +public class SparseTensor { + private final Operand indices; + private final Operand values; + private final Operand denseShape; + + /** + * Creates a SparseTensor + * + * @param indices A 2-D int64 tensor of shape `[N, ndims]`, which specifies the + * indices of the elements in the sparse tensor that contain nonzero values + * @param values A 1-D tensor of any type and shape `[N]`, which supplies the + * values for each element in `indices`. + * @param denseShape A 1-D int64 tensor of shape `[ndims]`, which specifies the + * dense_shape of the sparse tensor + * @throws IllegalArgumentException When building an eager SparseTensor if `dense_shape` is + * unknown or contains unknown elements (None or -1). + */ + public SparseTensor (Operand indices, Operand values, Operand denseShape) { + this.indices = indices; + this.values = values; + this.denseShape = denseShape; + } + + /** + * Gets the indices for the Sparse Tensor + * @return the indices + */ + public Operand getIndices() { + return indices; + } + + /** + * Get the values for the Sparse Tensor + * @return the values + */ + public Operand getValues() { + return values; + } + + /** + * Gets the dense shape for the Sparse Tensor + * + * @return the denseShape + */ + public Operand getDenseShape() { + return denseShape; + } + +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java new file mode 100644 index 00000000000..bd693da1312 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AUCTest.java @@ -0,0 +1,373 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class AUCTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float epsilon = 1e-4F; + + int numThresholds = 3; + float[] predArray = new float[] {0f, 0.5f, 0.3f, 0.9f}; + int[] trueArray = new int[] {0, 0, 1, 1}; + float[] sampleWeight = new float[] {1, 2, 3, 4}; + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(predArray); + Operand yTrue = tf.constant(trueArray); + AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + + session.run(tf.init()); + + Op update = instance.updateState(yTrue, yPred, null); + + for (int i = 0; i < 10; i++) { + 
session.run(update); + } + + Operand result = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(result, instance.result()); + } + } + } + + @Test + public void testCumulative() { + + // expected variable values after each run. + float[][] tp = {{2f, 1f, 0f}, {4f, 2f, 0f}, {6f, 3f, 0f}}; + float[][] fp = {{2f, 0f, 0f}, {4f, 0f, 0f}, {6f, 0f, 0f}}; + float[][] tn = {{0f, 2f, 2f}, {0f, 4f, 4f}, {0f, 6f, 6f}}; + float[][] fn = {{0f, 1f, 2f}, {0f, 2f, 4f}, {0f, 3f, 6f}}; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(predArray); + Operand yTrue = tf.constant(trueArray); + AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + + session.run(tf.init()); + + assertNull(instance.getTruePositives()); + assertNull(instance.getFalsePositives()); + assertNull(instance.getTrueNegatives()); + assertNull(instance.getFalseNegatives()); + + + + + for (int i = 0; i < 3; i++) { + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + session.evaluate(tp[i], instance.getTruePositives()); + session.evaluate(fp[i], instance.getFalsePositives()); + session.evaluate(tn[i], instance.getTrueNegatives()); + session.evaluate(fn[i], instance.getFalseNegatives()); + } + + // test reset + session.run(instance.resetStates()); + for (int i = 0; i < 3; i++) { + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + session.evaluate(tp[i], instance.getTruePositives()); + session.evaluate(fp[i], instance.getFalsePositives()); + session.evaluate(tn[i], instance.getTrueNegatives()); + session.evaluate(fn[i], instance.getFalseNegatives()); + } + } + } + + @Test + public void basicTestSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + AUC instance = new AUC<>(tf, numThresholds, 1001L, TFloat32.class); + assertEquals(numThresholds, instance.getNumThresholds()); + float[] 
expectedThresholds = new float[] {-1e-7f, 0.5f, 1 + 1e-7f}; + assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon); + + + Operand yPred = tf.constant(new float[] {0, 0, 1, 1}); + Operand yTrue = tf.constant(new float[] {0f, 0.5f, 0.3f, 0.9f}); + Operand sampleWeights = tf.constant(new float[] {1, 0, 0, 1}); + + Op update = instance.updateState(yTrue, yPred, sampleWeights); + session.run(update); + Operand result = instance.result(); + session.evaluate(1.0f, result); + } + } + + @Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yTrue = cast(tf, tf.constant(this.trueArray), TFloat32.class); + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + + Op update = instance.updateState(yTrue, yTrue, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(1f, result); + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + Operand result = instance.result(); + + // float expectedResult = (0.75f * 1 + 0.25f * 0); + session.evaluate(0.75f, result); + } + } + + @Test + public void testManualThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + AUC instance = new AUC<>(tf, new float[] {0.5f}, 1001L, TFloat32.class); + float[] expectedThresholds = new float[] {-AUC.EPSILON, 0.5f, 1 + AUC.EPSILON}; + assertArrayEquals(expectedThresholds, instance.getThresholds(), 
epsilon); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, null); + session.run(update); + Operand result = instance.result(); + + // float expectedResult = (0.75f * 1 + 0.25f * 0); + session.evaluate(0.75f, result); + } + } + + @Test + public void testWeightedRocInterpolation() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = new AUC<>(tf, this.numThresholds, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = (0.78571427f * 1 + 0.2857145f * 0); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedRocMajoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.ROC, + AUCSummationMethod.MAJORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = (1.0f + .5714285f * 0f); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedRocMinoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.ROC, + AUCSummationMethod.MINORING, 
+ 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + + float expectedResult = (0.5714285f + 0f * 0f); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedPrMajoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.PR, + AUCSummationMethod.MAJORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.4285715f + 0.5714285f; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedPrMinoring() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new AUC<>( + tf, + this.numThresholds, + AUCCurve.PR, + AUCSummationMethod.MINORING, + 1001L, + TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.7f * 0.4285715f + 0f * 0.5714285f; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testWeightedPrInterpolation() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand yPred = tf.constant(this.predArray); + Operand yTrue = tf.constant(this.trueArray); + Operand sampleWights = tf.constant(this.sampleWeight); + + AUC instance = + new 
AUC<>(tf, this.numThresholds, AUCCurve.PR, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(yTrue, yPred, sampleWights); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.916613f; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + new AUC<>(tf, -1, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testExtraDims() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // logits = scipy.special.expit(-np.array([[[-10., 10., -10.], [10., -10., 10.]], + // [[-12., 12., -12.], [12., -12., 12.]]], + // dtype=np.float32)) + float[][][] logitsArray = { + { + {9.99954602e-01f, 4.53978687e-05f, 9.99954602e-01f}, + {4.53978687e-05f, 9.99954602e-01f, 4.53978687e-05f} + }, + { + {9.99993856e-01f, 6.14417460e-06f, 9.99993856e-01f}, + {6.14417460e-06f, 9.99993856e-01f, 6.14417460e-06f} + } + }; + + long[][][] labelArray = { + {{1, 0, 0}, {1, 0, 0}}, + {{0, 1, 1}, {0, 1, 1}} + }; + + Operand logits = tf.constant(logitsArray); + Operand labels = tf.constant(labelArray); + + AUC instance = new AUC<>(tf, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(labels, logits, null); + session.run(update); + Operand result = instance.result(); + float expectedResult = 0.5f; + session.evaluate(expectedResult, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java new file mode 100644 index 00000000000..48cac95b8a6 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/AccuracyTest.java @@ -0,0 +1,130 @@ +/* Copyright 2020 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class AccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2, 3, 4}; + float[] predArray = {1, 2, 3, 4}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4F, total); + session.evaluate(4, count); + session.evaluate(1F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = 
TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] trueArray = {2, 1}; + float[] predArray = {2, 0}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(.5F, total); + session.evaluate(.7, count); + session.evaluate(0.71428573f, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Accuracy instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] trueArray = {2, 1}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // 2nd run + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(1.4F, total); + session.evaluate(1.4, count); + session.evaluate(1.0F, result); + + // new instance same graph + instance = new Accuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, labels, 
sampleWeight); + session.run(op); + total = instance.getTotal(); + count = instance.getCount(); + result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // reset variables + session.run(instance.resetStates()); + result = instance.result(); + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java new file mode 100644 index 00000000000..d203815f4ab --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryAccuracyTest.java @@ -0,0 +1,176 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class BinaryAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 0}; + float[] predArray = {1, 0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(2, count); + session.evaluate(1F, result); + } + } + + @Test + public void testPredictionSqueeze() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 0}; + float[] predArray = {1, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1, 1))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = 
instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(4, count); + session.evaluate(0.5F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 1}; + float[] predArray = {1, 0}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.5F, total); + session.evaluate(.7, count); + session.evaluate(0.71428573f, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + float[] trueArray = {2, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 1))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + + // 2nd run + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + 
session.evaluate(0.4F, total); + session.evaluate(1.4, count); + session.evaluate(0.2857143F, result); + + // new instance same graph + instance = new BinaryAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + total = instance.getTotal(); + count = instance.getCount(); + result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + + // reset variables + session.run(instance.resetStates()); + session.evaluate(0.0, total); + session.evaluate(0.0, count); + + op = instance.updateState(labels, labels, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(0.2F, total); + session.evaluate(.7, count); + session.evaluate(0.2857143F, result); + } + } + + @Test + public void testBinaryAccuracyAThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryAccuracy instance = new BinaryAccuracy<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 1, 0, 0}; + float[] predArray = {0.9f, 0.6f, 0.4f, 0.8f}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 1))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 1))); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(4, count); + session.evaluate(0.5F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java new file mode 100644 index 00000000000..aea2e4e0d6e --- /dev/null +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalAccuracyTest.java @@ -0,0 +1,153 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +public class CategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + 
Operand result = instance.result(); + session.evaluate(2F, total); + session.evaluate(2, count); + session.evaluate(1F, result); + } + } + + @Test + public void testSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } + + @Test + public void testVariableState() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalAccuracy instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 0, 1, + 0, 1, 0 + }; + float[] predArray = { + 0.1f, 0.1f, 0.8f, + 0.05f, 0.95f, 0f + }; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = + tf.reshape(tf.constant(new float[] {.5F, .2F}), tf.constant(Shape.of(2, 1))); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + 
session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // 2nd run + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(1.4F, total); + session.evaluate(1.4, count); + session.evaluate(1.0F, result); + + // new instance same graph + instance = new CategoricalAccuracy<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + total = instance.getTotal(); + count = instance.getCount(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + + // reset variables + session.run(instance.resetStates()); + session.evaluate(0, total); + session.evaluate(0, count); + session.evaluate(0, result); + + op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + result = instance.result(); + session.evaluate(0.7F, total); + session.evaluate(.7, count); + session.evaluate(1.0F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java new file mode 100644 index 00000000000..4bd8d99586e --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalseNegativesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class FalseNegativesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + FalseNegatives instance = new FalseNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(3.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + FalseNegatives instance = new FalseNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(5.0, result); 
+ } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + FalseNegatives instance = + new FalseNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {1.f, 4.f, 6.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(new double[][] {{3.0}, {5.0}, {7.0}, {4.0}}); + FalseNegatives instance = + new FalseNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {4., 16., 23.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java new file mode 100644 index 00000000000..2584c7a3244 
--- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/FalsePositivesTest.java @@ -0,0 +1,148 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class FalsePositivesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + FalsePositives instance = new FalsePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + 
session.evaluate(7.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + FalsePositives instance = new FalsePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(14.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + FalsePositives instance = + new FalsePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {7.f, 4.f, 2.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = + tf.constant( + new double[][] { + {1.0, 2.0, 3.0, 5.0}, + {7.0, 
11.0, 13.0, 17.0}, + {19.0, 23.0, 29.0, 31.0}, + {19.0, 23.0, 29.0, 31.0} + }); + FalsePositives instance = + new FalsePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {125., 42., 12.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java new file mode 100644 index 00000000000..fc08455d1c7 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanIoUTest.java @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +public class MeanIoUTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final long numClasses = 2L; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testUnweighted"); + Operand predictions = tf.constant(new long[] {0, 1, 0, 1}); + Operand labels = tf.constant(new long[] {0, 0, 1, 1}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + double expected_result = (1. / (2. + 2. - 1.) + 1. / (2. + 2. 
- 1.)) / 2.; + session.evaluate(expected_result, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testWeighted"); + Operand predictions = tf.constant(new long[] {0, 1, 0, 1}); + Operand labels = tf.constant(new long[] {0, 0, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {0.2f, 0.3f, 0.4f, 0.1f}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + float expected_result = (0.2f / (0.6f + 0.5f - 0.2f) + 0.1f / (0.4f + 0.5f - 0.1f)) / 2f; + session.evaluate(expected_result, result); + } + } + + @Test + public void testMultiDimInput() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testMultiDimInput"); + + Operand predictions = tf.constant(new long[][] {{0, 1}, {0, 1}}); + Operand labels = tf.constant(new long[][] {{0, 0}, {1, 1}}); + Operand sampleWeight = tf.constant(new float[][] {{0.2f, 0.3f}, {0.4f, 0.1f}}); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + float expected_result = (0.2f / (0.6f + 0.5f - 0.2f) + 0.1f / (0.4f + 0.5f - 0.1f)) / 2f; + session.evaluate(expected_result, result); + } + } + + @Test + public void testZeroValidEntries() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testZeroValidEntries"); + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Operand result = instance.result(); + session.evaluate(0.0f, result); + } + } + + 
@Test + public void testZeroAndNonZeroEntries() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF().withSubScope("testZeroAndNonZeroEntries"); + Operand predictions = tf.constant(new float[] {1}); + Operand labels = tf.constant(new int[] {1}); + + MeanIoU instance = new MeanIoU<>(tf, numClasses, 1001L, TFloat32.class); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float expected_result = (0f + 1f / (1f + 1f - 1f)) / 1f; + session.evaluate(expected_result, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java new file mode 100644 index 00000000000..0bb9392b8b0 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanRelativeErrorTest.java @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class MeanRelativeErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] predArray = new float[][] {{2, 4, 6, 8}}; + float[][] trueArray = new float[][] {{1, 3, 2, 3}}; + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + + MeanRelativeError instance = + new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + double expected_result = 1.25; + session.evaluate(expected_result, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[] predArray = new float[] {2, 4, 6, 8}; + float[] trueArray = new float[] {1, 3, 2, 3}; + float[] sampleWeightArray = new float[] {0.2f, 0.3f, 0.5f, 0f}; + + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + Operand sampleWeight = tf.constant(sampleWeightArray); + + MeanRelativeError instance = + new MeanRelativeError<>(tf, labels, 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = 
instance.result(); + + double expectedResult = 1.3; + session.evaluate(expectedResult, result); + } + } + + @Test + public void testZeroNormalizer() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[] predArray = new float[] {2, 4}; + int[] trueArray = new int[] {1, 3}; + + Operand predictions = tf.constant(predArray); + Operand labels = tf.constant(trueArray); + + MeanRelativeError instance = + new MeanRelativeError<>( + tf, cast(tf, tf.zerosLike(labels), TFloat32.class), 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.run(tf.init()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + double expectedResult = 0; + session.evaluate(expectedResult, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java new file mode 100644 index 00000000000..ce473bbdf34 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanTensorTest.java @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class MeanTensorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand values = tf.constant(new long[] {100, 40}); + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); + session.run(tf.init()); + Op update = instance.updateState(values, null); + session.run(update); + Operand result = instance.result(); + double[] expected_result = new double[] {100, 40}; + session.evaluate(expected_result, result); + + session.evaluate(expected_result, instance.getTotal()); + session.evaluate(new double[] {1, 1}, instance.getCount()); + + session.run(instance.resetStates()); + session.evaluate(new double[] {0, 0}, instance.getTotal()); + session.evaluate(new double[] {0, 0}, instance.getCount()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand values = tf.constant(new long[] {100, 30}); + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat64.class); + session.run(tf.init()); + + // check scalar weight + Op update = instance.updateState(values, tf.constant(0.5f)); + session.run(update); + Operand result = instance.result(); + double[] expected_result = new double[] {100, 30}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {50, 15}, instance.getTotal()); + session.evaluate(new double[] {0.5, 0.5}, instance.getCount()); + + // check weights 
not scalar and weights rank matches values rank + values = tf.constant(new long[] {1, 5}); + update = instance.updateState(values, tf.constant(new double[] {1f, 0.2f})); + session.run(update); + result = instance.result(); + expected_result = new double[] {51 / 1.5, 16 / 0.7}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {51, 16}, instance.getTotal()); + session.evaluate(new double[] {1.5, .7}, instance.getCount()); + + // check weights broadcast + values = tf.constant(new long[] {1, 2}); + update = instance.updateState(values, tf.constant(0.5f)); + session.run(update); + result = instance.result(); + expected_result = new double[] {51.5 / 2, 17 / 1.2}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {51.5, 17}, instance.getTotal()); + session.evaluate(new double[] {2, 1.2}, instance.getCount()); + + // check weights squeeze + values = tf.constant(new long[] {1, 5}); + Operand sampleWeight = tf.constant(new double[][] {{1}, {0.2}}); + update = instance.updateState(values, sampleWeight); + session.run(update); + result = instance.result(); + expected_result = new double[] {52.5 / 3, 18 / 1.4}; + session.evaluate(expected_result, result); + session.evaluate(new double[] {52.5, 18}, instance.getTotal()); + session.evaluate(new double[] {3, 1.4}, instance.getCount()); + } + } + + @Test + public void testWeightedExpand() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + // check weights expand + MeanTensor instance = new MeanTensor<>(tf, 1001L, TFloat32.class); + + Operand values = tf.constant(new long[][] {{1}, {5}}); + Operand sampleWeight = tf.constant(new float[] {1f, 0.2f}); + Op update = instance.updateState(values, sampleWeight); + session.run(update); + Operand result = instance.result(); + session.evaluate(tf.constant(new float[][] {{1f}, {5f}}), result); + session.evaluate(tf.constant(new float[][] {{1f}, {1f}}), instance.getTotal()); + 
session.evaluate(tf.constant(new float[][] {{1f}, {0.2f}}), instance.getCount()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java new file mode 100644 index 00000000000..a817a3dc5df --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionAtRecallTest.java @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class PrecisionAtRecallTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) session.evaluate(initialPrecision, instance.result()); + } + } + + private int[][] generateRandomArray(int dim1, int dim2) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(2); + } + } + + return result; + } + + @Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new 
PrecisionAtRecall<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighRecall() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.8f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.5f, 0.4f, 0.5f, 0.6f, 0.8f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.8f, precision); + } + } + + @Test + public void testUnweightedLowRecall() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.15f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.5f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = 
session.getTF(); + PrecisionAtRecall instance = + new PrecisionAtRecall<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {2, 2, 1, 1, 1, 1, 1, 2, 2, 2}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(2.f / 3.f, precision); + } + } + + @Test + public void testInvalidSensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new PrecisionAtRecall<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new PrecisionAtRecall<>(tf, 0.7f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java new file mode 100644 index 00000000000..cfe5b483e2b --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PrecisionTest.java @@ -0,0 +1,333 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +public class PrecisionTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Precision instance = + new Precision<>(tf, new float[] {0.3f, 0.72f}, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1001L)); + + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) { + session.run(update); + } + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(initialPrecision, instance.result()); + } + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + 
Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new long[][] {{1, 0, 1, 0}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(0.5, precision); + } + } + + @Test + public void testUnweightedAllIncorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = new Precision<>(tf, 0.5f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniformInt(tf.constant(Shape.of(100, 1)), tf.constant(0), tf.constant(2)); + Operand labels = tf.math.sub(tf.constant(1), predictions); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(0.0f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new long[][] {{1, 0, 1, 0}, {1, 0, 1, 0}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}, {1, 0, 0, 1}}); + Operand sampleWeight = tf.constant(new double[][] {{1, 2, 3, 4}, {4, 3, 2, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand precision = instance.result(); + + double weightedTP = 3.0f + 4.0f; + double weightedPositives = (1.0f + 3.0f) + (4.0f + 2.0f); + double expectedPrecision = weightedTP / weightedPositives; + + session.evaluate(expectedPrecision, precision); + } + } + + @Test + public void testDivByZero() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); 
+ Precision instance = new Precision<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new int[] {0, 0, 0, 0}); + Operand labels = tf.constant(new int[] {0, 0, 0, 0}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(0, precision); + } + } + + @Test + public void testUnweightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f, 0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + float[] expected = new float[] {0.5f, 0.f}; + + session.evaluate(expected, precision); + } + } + + @Test + public void testWeightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f}, {0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1}, {1, 0}}); + Operand sampleWeight = tf.constant(new float[][] {{4, 0}, {3, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand precision = instance.result(); + + float weightedTP = 0f + 3.f; + float weightedPositives = (0f + 3.f) + (4.f + 0.f); + float expectedPrecision = weightedTP / weightedPositives; + + Float[] expected = new Float[] {expectedPrecision, 0f}; + session.evaluate(expected, precision); + } + } + + @Test + public void testMultipleUpdates() { + try (TestSession 
session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Precision instance = + new Precision<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1f, 0f}, {0.6f, 0f}}); + Operand labels = tf.constant(new long[][] {{0, 1}, {1, 0}}); + Operand sampleWeight = tf.constant(new double[][] {{4, 0}, {3, 1}}); + Op update = instance.updateState(labels, predictions, sampleWeight); + for (int i = 0; i < 2; i++) session.run(update); + Operand precision = instance.result(); + + double weighted_tp = (0 + 3.) + (0 + 3.); + double weighted_positives = ((0 + 3.) + (4. + 0.)) + ((0 + 3.) + (4. + 0.)); + double expected_precision = weighted_tp / weighted_positives; + + double[] expected = new double[] {expected_precision, 0f}; + session.evaluate(expected, precision); + } + } + + @Test + public void testUnweightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 3 + Precision instance = new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + session.evaluate(1.0f / 3.0f, precision); + } + } + + @Test + public void testWeightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 3 + Precision instance = new Precision<>(tf, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[] {0.2f, 0.1f, 0.4f, 0f, 0.2f}); + Operand labels = tf.constant(new long[] {0, 1, 1, 0, 1}); + Operand sampleWeight = tf.constant(new float[][] {{1, 4, 2, 3, 5}}); + Op 
update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.4f, 0.2f, 0.2f}}); + labels = tf.constant(new long[][] {{1, 0, 1, 1, 1}}); + update = instance.updateState(labels, predictions, tf.constant(3.f)); + session.run(update); + + Operand precision = instance.result(); + + float tp = (2f + 5f) + (3f + 3f); + float predicted_positives = (1f + 2f + 5f) + (3f + 3f + 3f); + float expected_precision = tp / predicted_positives; + session.evaluate(expected_precision, precision); + } + } + + @Test + public void testUnweightedClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set classId to 2 + Precision instance = new Precision<>(tf, null, null, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0f, 0f, 0.2f}}); + labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + labels = tf.constant(new long[][] {{0, 1, 0, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(0.5f, precision); + session.evaluate(1, 
instance.getTruePositives()); + session.evaluate(1, instance.getFalsePositives()); + } + } + + @Test + public void testUnweightedTopKAndClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK and classId to 2 + Precision instance = new Precision<>(tf, null, 2, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + + predictions = tf.constant(new float[][] {{1f, 1f, 0.9f, 1f, 1f}}); + labels = tf.constant(new long[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + } + } + + @Test + public void testUnweightedTopKAndThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + // set topK to 2 + Precision instance = new Precision<>(tf, 0.7f, 2, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new long[][] {{0, 1, 1, 0, 1}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1, precision); + session.evaluate(1, instance.getTruePositives()); + session.evaluate(0, instance.getFalsePositives()); + } + } +} diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java new file mode 100644 index 00000000000..bd3a5273668 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallAtPrecisionTest.java @@ -0,0 +1,207 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class RecallAtPrecisionTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.7f, 1001L, TFloat32.class); + 
session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + labels = tf.math.mul(labels, tf.constant(2.0f)); + + Op update = instance.updateState(labels, predictions); + + for (int i = 0; i < 10; i++) { + session.run(update); + } + + Operand initialPrecision = instance.result(); + + for (int i = 0; i < 10; i++) { + session.evaluate(initialPrecision, instance.result()); + } + } + } + + private int[][] generateRandomArray(int dim1, int dim2, int maxVal) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(maxVal); + } + } + + return result; + } + + @Test + public void test_unweighted_all_correct() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1, 2); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighPrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.75f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] { + 0.05f, 0.1f, 0.2f, 0.3f, 
0.3f, 0.35f, 0.4f, 0.45f, 0.5f, 0.6f, 0.9f, 0.95f + }); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.5f, precision); + } + } + + @Test + public void testUnweightedLowPrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 2.0f / 3f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] { + 0.05f, 0.1f, 0.2f, 0.3f, 0.3f, 0.35f, 0.4f, 0.45f, 0.5f, 0.6f, 0.9f, 0.95f + }); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(5.f / 6f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 0.75f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.1f, 0.2f, 0.3f, 0.5f, 0.6f, 0.9f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 1, 0, 0, 0, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {1, 2, 1, 2, 1, 2, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.6f, precision); + } + } + + @Test + public void testUnachievablePrecision() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RecallAtPrecision instance = + new RecallAtPrecision<>(tf, 2.0f / 3f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand 
predictions = tf.constant(new float[] {0.1f, 0.2f, 0.3f, 0.9f}); + Operand labels = tf.constant(new long[] {1, 1, 0, 0}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + // The highest possible precision is 1/2 which is below the required + session.evaluate(0f, precision); + } + } + + @Test + public void test_invalid_sensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new RecallAtPrecision<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void test_invalid_num_thresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new RecallAtPrecision<>(tf, 0.7f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java new file mode 100644 index 00000000000..bd9fbb1ab66 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RecallTest.java @@ -0,0 +1,333 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; + +public class RecallTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = + new Recall<>(tf, new float[] {0.3f, 0.72f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class); + Operand labels = + tf.random.randomUniform(tf.constant(Shape.of(10, 3)), TFloat32.class); + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialRecall = instance.result(); + for (int i = 0; i < 10; i++) session.evaluate(initialRecall, instance.result()); + } + } + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1, 0, 1, 0}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.5f, instance.result()); + } + } + + private int[][] generateRandomArray(int dim1, int dim2, int maxInt) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = 
random.nextInt(maxInt); + } + } + + return result; + } + + @Test + public void testUnweightedAllIncorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] array = generateRandomArray(100, 1, 2); + Operand predictions = tf.dtypes.cast(tf.constant(array), TFloat32.class); + Operand labels = + tf.dtypes.cast(tf.math.sub(tf.constant(1), tf.constant(array)), TFloat32.class); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.f, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[][] { + {1, 0, 1, 0}, + {0, 1, 0, 1} + }); + Operand labels = + tf.constant( + new float[][] { + {0, 1, 1, 0}, + {1, 0, 0, 1} + }); + + Operand sampleWeights = + tf.constant( + new float[][] { + {1, 2, 3, 4}, + {4, 3, 2, 1} + }); + Op update = instance.updateState(labels, predictions, sampleWeights); + session.run(update); + + float weightedTp = 3.0f + 1.0f; + float weightedT = (2.0f + 3.0f) + (4.0f + 1.0f); + float expectedRecall = weightedTp / weightedT; + + session.evaluate(expectedRecall, instance.result()); + } + } + + @Test + public void testDivByZero() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[] {0, 0, 0, 0}); + Operand labels = tf.constant(new float[] {0, 0, 0, 0}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0f, instance.result()); + } + } + + 
@Test + public void testUnweightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, new float[] {0.5f, 0.7f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{1, 0, 0.6f, 0}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Float[] expected = new Float[] {0.5f, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testWeightedWithThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); + Operand predictions = tf.constant(new float[][] {{1, 0}, {0.6f, 0}}); + Operand weights = tf.constant(new float[][] {{1, 4}, {3, 2}}); + + Op update = instance.updateState(labels, predictions, weights); + session.run(update); + + float weightedTp = 0 + 3.f; + float weightedPositives = (0 + 3.f) + (4.f + 0.f); + float expectedRecall = weightedTp / weightedPositives; + float[] expected = new float[] {expectedRecall, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testMultipleUpdates() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, new float[] {0.5f, 1.f}, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1}, {1, 0}}); + Operand predictions = tf.constant(new float[][] {{1, 0}, {0.6f, 0}}); + Operand weights = tf.constant(new float[][] {{1, 4}, {3, 2}}); + + Op update = instance.updateState(labels, predictions, weights); + for (int i = 0; i < 2; i++) 
session.run(update); + + float weightedTp = (0f + 3.f) + (0f + 3.f); + float weightedPositives = ((0f + 3.f) + (4.f + 0.f)) + ((0f + 3.f) + (4.f + 0.f)); + float expectedRecall = weightedTp / weightedPositives; + float[] expected = new float[] {expectedRecall, 0f}; + session.evaluate(expected, instance.result()); + } + } + + @Test + public void testUnweightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0f, 1f, 1f, 0f, 0f}}); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.5f, 0f, 0.2f}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.5f, instance.result()); + } + } + + @Test + public void testWeightedTopK() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, null, null, 3, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 1}}); + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.4f, 0f, 0.2f}}); + Operand weights = tf.constant(new float[][] {{1, 4, 2, 3, 5}}); + + Op update = instance.updateState(labels, predictions, weights); + session.run(update); + + labels = tf.constant(new float[][] {{1, 0, 1, 1, 1}}); + predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.4f, 0.2f, 0.2f}}); + weights = tf.constant(3.f); + + update = instance.updateState(labels, predictions, weights); + session.run(update); + + float weightedTp = (2 + 5) + (3 + 3); + float weightedPositives = (4 + 2 + 5) + (3 + 3 + 3 + 3); + float expectedRecall = weightedTp / weightedPositives; + session.evaluate(expectedRecall, instance.result()); + } + } + + @Test + public void testUnweightedClassId() { + try 
(TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, null, null, null, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(0f, instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0f, 0f, 0.2f}}); + labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{0.2f, 0.1f, 0.6f, 0f, 0.2f}}); + labels = tf.constant(new float[][] {{0, 1, 0, 0, 0}}); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + } + } + + @Test + public void testUnweightedTopKAndClassId() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, null, null, 2, 2, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.6f, 0.3f, 0, 0.2f}}); + Operand labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(1f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(0f, 
instance.getFalseNegatives()); + + predictions = tf.constant(new float[][] {{1, 1, 0.9f, 1, 1}}); + labels = tf.constant(new float[][] {{0, 1, 1, 0, 0}}); + + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(1f, instance.getFalseNegatives()); + } + } + + @Test + public void testUnweightedTopKAndThreshold() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Recall instance = new Recall<>(tf, null, 0.7f, 2, null, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = tf.constant(new float[][] {{0.2f, 0.8f, 0.6f, 0f, 0.2f}}); + Operand labels = tf.constant(new float[][] {{1, 1, 1, 0, 1}}); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + session.evaluate(0.25f, instance.result()); + session.evaluate(1f, instance.getTruePositives()); + session.evaluate(3f, instance.getFalseNegatives()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java new file mode 100644 index 00000000000..c9ced9f5946 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/RootMeanSquaredErrorTest.java @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +public class RootMeanSquaredErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RootMeanSquaredError instance = + new RootMeanSquaredError<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand labels = tf.constant(new float[] {2, 4, 6}); + Operand predictions = tf.constant(new float[] {1, 3, 2}); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(18, total); + session.evaluate(3, count); + session.evaluate(Math.sqrt(6), result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + RootMeanSquaredError instance = + new RootMeanSquaredError<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand labels = tf.constant(new float[][] {{2, 4, 6, 8}}); + Operand predictions = tf.constant(new float[][] {{1, 3, 2, 3}}); + Operand sampleWeight = tf.constant(new double[][] {{0, 1, 0, 1}}); + + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); 
+ session.evaluate(26, total); + session.evaluate(2, count); + session.evaluate(Math.sqrt(13), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java new file mode 100644 index 00000000000..a65dc3b53da --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SensitivityAtSpecificityTest.java @@ -0,0 +1,185 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class SensitivityAtSpecificityTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + labels = tf.math.mul(labels, tf.constant(2.0f)); + + // instance.setDebug(session.getGraphSession()); + Op update = instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialSensitivity = instance.result(); + + for (int i = 0; i < 10; i++) session.evaluate(initialSensitivity, instance.result()); + + // instance.setDebug(null); + + } + } + + private int[][] generateRandomArray(int dim1, int dim2) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(2); + } + } + + return result; + } + + 
@Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = cast(tf, tf.constant(trueArray), TFloat32.class); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighSpecificity() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.8f, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.45f, 0.5f, 0.8f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.8, precision); + } + } + + @Test + public void testUnweightedLowSpecificity() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, 
null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.6f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SensitivityAtSpecificity instance = + new SensitivityAtSpecificity<>(tf, 0.4f, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + Operand sampleWeight = tf.constant(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.675, precision); + } + } + + @Test + public void testInvalidSensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new SensitivityAtSpecificity<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new SensitivityAtSpecificity<>(tf, 0.7f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java new file mode 100644 index 00000000000..ff5834eda8e --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SpecificityAtSensitivityTest.java @@ -0,0 +1,184 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class SpecificityAtSensitivityTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(); + + @Test + public void testValueIsIdempotent() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + + Operand predictions = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + Operand labels = + tf.random.randomUniform( + tf.constant(Shape.of(10, 3)), TFloat32.class, RandomUniform.seed(1L)); + + // instance.setDebug(session.getGraphSession()); + Op update = 
instance.updateState(labels, predictions, null); + + for (int i = 0; i < 10; i++) session.run(update); + + Operand initialSpecificity = instance.result(); + + for (int i = 0; i < 10; i++) session.evaluate(initialSpecificity, instance.result()); + } + } + + private int[][] generateRandomArray(int dim1, int dim2) { + int[][] result = new int[dim1][dim2]; + for (int i = 0; i < dim1; i++) { + for (int j = 0; j < dim2; j++) { + result[i][j] = random.nextInt(2); + } + } + + return result; + } + + @Test + public void testUnweightedAllCorrect() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.7f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[][] predArray = generateRandomArray(100, 1); + int[][] trueArray = new int[100][1]; // 100,1 + System.arraycopy(predArray, 0, trueArray, 0, predArray.length); + Operand predictions = cast(tf, tf.constant(predArray), TFloat32.class); + Operand labels = tf.constant(trueArray); + labels = tf.math.mul(labels, tf.constant(2)); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(1f, precision); + } + } + + @Test + public void testUnweightedHighSensitivity() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.8f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant(new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.1f, 0.45f, 0.5f, 0.8f, 0.9f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.4f, precision); + } + } + + @Test + public void 
testUnweightedLowSensitivity() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.4f, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.6f, precision); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SpecificityAtSensitivity instance = + new SpecificityAtSensitivity<>(tf, 0.4f, 1001L, TFloat32.class); + session.run(instance.resetStates()); + Operand predictions = + tf.constant( + new float[] {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.01f, 0.02f, 0.25f, 0.26f, 0.26f}); + Operand labels = tf.constant(new long[] {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}); + Operand sampleWeight = tf.constant(new float[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + + Operand precision = instance.result(); + + session.evaluate(0.4f, precision); + } + } + + @Test + public void testInvalidSensitivity() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new SpecificityAtSensitivity<>(tf, -1f, 1001L, TFloat32.class); + } + }); + } + + @Test + public void testInvalidNumThresholds() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + new SpecificityAtSensitivity<>(tf, 0.4f, -1, 1001L, TFloat32.class); + } + }); + } +} diff --git 
a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java new file mode 100644 index 00000000000..941f882b8c8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SumTest.java @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SumTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + /** Test of call method, of class Sum. 
*/ + @Test + public void testUnWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Sum instance = new Sum<>(tf, 1001L, TFloat32.class); + session.run(instance.resetStates()); + assertEquals(TFloat32.class, instance.getResultType()); + session.evaluate(0f, instance.getTotal()); + + Op update = instance.updateState(tf.constant(100f), null); + session.run(update); + session.evaluate(100f, instance.result()); + session.evaluate(100f, instance.getTotal()); + + update = instance.updateState(tf.constant(new float[] {1, 5}), null); + session.run(update); + session.evaluate(106f, instance.result()); + session.evaluate(106f, instance.getTotal()); + + session.run(instance.resetStates()); + session.evaluate(0f, instance.getTotal()); + } + } + + @Test + public void testSumWithSampleWeight() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Sum instance = new Sum<>(tf, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + // check scalar weight + Op op = instance.updateState(tf.constant(100f), tf.constant(0.5)); + session.run(op); + Operand result = instance.result(); + session.evaluate(50.0, instance.getTotal()); + session.evaluate(50.0, result); + + // check weights not scalar and weights rank matches values rank + op = + instance.updateState(tf.constant(new float[] {1, 5}), tf.constant(new double[] {1, 0.2})); + session.run(op); + result = instance.result(); + session.evaluate(52., instance.getTotal()); + session.evaluate(52., result); + + // check weights broadcast + op = instance.updateState(tf.constant(new float[] {1, 2}), tf.constant(0.5)); + session.run(op); + result = instance.result(); + session.evaluate(53.5, instance.getTotal()); + session.evaluate(53.5, result); + + // check weights squeeze + op = + instance.updateState( + tf.constant(new float[] {1, 5}), tf.constant(new double[][] {{1}, {0.2}})); + session.run(op); + result = 
instance.result(); + session.evaluate(55.5, instance.getTotal()); + session.evaluate(55.5, result); + + // check weights expand + op = + instance.updateState( + tf.constant(new float[][] {{1}, {5}}), tf.constant(new double[] {1, 0.2})); + session.run(op); + result = instance.result(); + session.evaluate(57.5, instance.getTotal()); + session.evaluate(57.5, result); + + // heck values reduced to the dimensions of weight + op = + instance.updateState( + tf.constant(new float[][][] {{{1.f, 2.f}, {3.f, 2.f}, {0.5f, 4.f}}}), + tf.constant(new double[] {0.5})); + session.run(op); + result = instance.result(); + session.evaluate(63.75, instance.getTotal()); + session.evaluate(63.75, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java new file mode 100644 index 00000000000..023796ba367 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TopKCategoricalAccuracyTest.java @@ -0,0 +1,103 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class TopKCategoricalAccuracyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCorrectness() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + Operand labels = tf.constant(new float[][] {{0, 0, 1}, {0, 1, 0}}); + Operand predictions = + tf.constant(new double[][] {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}}); + + Op update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(1., instance.result()); + + // With `k` < 5. + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted1", 1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + + // With `k` > 5. 
+ labels = + tf.constant( + new float[][] { + {0, 0, 1, 0, 0, 0, 0}, + {0, 1, 0, 0, 0, 0, 0} + }); + predictions = + tf.constant( + new double[][] { + {0.5f, 0.9f, 0.1f, 0.7f, 0.6f, 0.5f, 0.4f}, + {0.05f, 0.95f, 0f, 0f, 0f, 0f, 0f} + }); + instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testUnweighted6", 6, 1001L, TFloat64.class); + session.run(instance.resetStates()); + update = instance.updateState(labels, predictions, null); + session.run(update); + session.evaluate(0.5, instance.result()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + TopKCategoricalAccuracy instance = + new TopKCategoricalAccuracy<>(tf, "TopK_testWeighted", 5, 1001L, TFloat64.class); + session.run(instance.resetStates()); + + Operand labels = + tf.constant( + new double[][] { + {1, 0, 2}, + {1, 0, 0}, + {0, 0, 1} + }); + Operand predictions = + tf.constant( + new double[][] { + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f}, + {0f, 0.9f, 0.1f} + }); + + Operand sampleWeight = tf.constant(new double[] {1, 0, 1}); + + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + session.evaluate(1., instance.result()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java new file mode 100644 index 00000000000..1a68c2ed8b8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TrueNegativesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class TrueNegativesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + TrueNegatives instance = new TrueNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(3.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = 
tf.constant(this.sampleWeightArray); + TrueNegatives instance = new TrueNegatives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(4.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + TrueNegatives instance = + new TrueNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {2.f, 5.f, 7.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(new double[][] {{0.0, 2.0, 3.0, 5.0}}); + TrueNegatives instance = + new TrueNegatives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {5., 
15., 23.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java new file mode 100644 index 00000000000..c22c1245d97 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/TruePositivesTest.java @@ -0,0 +1,141 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +public class TruePositivesTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + long[][] trueArray = { + {0, 1, 0, 1, 0}, {0, 0, 1, 1, 1}, + {1, 1, 1, 1, 0}, {0, 0, 0, 0, 1} + }; + + long[][] predArray = { + {0, 0, 1, 1, 0}, {1, 1, 1, 1, 1}, + {0, 1, 0, 1, 0}, {1, 1, 1, 1, 1} + }; + + double[] sampleWeightArray = {1., 1.5, 2., 2.5}; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + TruePositives instance = new TruePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + + session.evaluate(7.0, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = tf.constant(this.predArray); + Operand labels = tf.constant(this.trueArray); + Operand sampleWeight = tf.constant(this.sampleWeightArray); + TruePositives instance = new TruePositives<>(tf, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + + session.evaluate(12.0, result); + } + } + + @Test + public void testUnweightedWithThresholds() { + try (TestSession session = 
TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + TruePositives instance = + new TruePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat32.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, null); + session.run(update); + Operand result = instance.result(); + float[] expected = new float[] {6.f, 3.f, 1.f}; + session.evaluate(expected, result); + } + } + + @Test + public void testWeightedWithThresholds() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Operand predictions = + tf.constant( + new float[][] { + {0.9f, 0.2f, 0.8f, 0.1f}, + {0.2f, 0.9f, 0.7f, 0.6f}, + {0.1f, 0.2f, 0.4f, 0.3f}, + {0f, 1f, 0.7f, 0.3f} + }); + Operand labels = + tf.constant( + new long[][] { + {0, 1, 1, 0}, + {1, 0, 0, 0}, + {0, 0, 0, 0}, + {1, 1, 1, 1} + }); + + Operand sampleWeight = tf.constant(37.); + TruePositives instance = + new TruePositives<>(tf, new float[] {0.15f, 0.5f, 0.85f}, 1001L, TFloat64.class); + session.run(instance.getInitializer()); + Op update = instance.updateState(labels, predictions, sampleWeight); + session.run(update); + Operand result = instance.result(); + double[] expected = new double[] {222., 111., 37.}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 63d666f8640..4330fa0aed7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -63,6 +63,7 @@ private void testValid( TestSession testSession, Ops tf, Operand weights, Operand values, Class type) { Op staticOp = MetricsHelper.assertBroadcastable(tf, weights, values); + testSession.run(staticOp); // dynamic test Operand weightsPlaceholder = tf.placeholder(type);