
Metrics phase 2 #222


Merged: 115 commits from metrics_phase_2 into tensorflow:master on Apr 22, 2021

Conversation

JimClarke5
Contributor

This is Phase 2 of Metrics.

This includes metrics that are not built on losses.
The new classes are the xxxxAccuracy set of classes
and the confusion-matrix set of classes: AUC, FalseNegatives, FalsePositives, PrecisionAtRecall, RecallAtPrecision, Recall, SensitivityAtSpecificity and SpecificityAtSensitivity.

This PR is based on the current master branch, including metrics1 and the metrics/losses generic cleanup, and is not dependent on other PRs.

@JimClarke5
Contributor Author

@karllessard There seems to be something wrong with the build environment. This is the second PR that created this error on the quick build.

Failed to execute goal on project tensorflow-core-api: Could not resolve dependencies for project org.tensorflow:tensorflow-core-api:jar:0.3.0-SNAPSHOT: Could not find artifact org.tensorflow:tensorflow-core-api:jar:linux-x86_64:0.3.0-20210211.071905-248 in ossrh-snapshots (https://oss.sonatype.org/content/repositories/snapshots) -> [Help 1]

@rnett
Contributor

rnett commented Feb 16, 2021

@JimClarke5 I'm seeing that as well when building locally, too.

FYI, your list is missing Precision, although it's implemented.

Would it be possible to add a ConfusionMatrix metric, and a metric to calculate all of the true/false positive/negative stats in the same metric?

@JimClarke5
Contributor Author

JimClarke5 commented Feb 16, 2021

The core method for confusion metrics is MetricsHelper.updateConfusionMatrixVariables(), which handles all four possibilities: true and false positives, and true and false negatives.
As far as a ConfusionMatrix metric, how would that be defined? sklearn.metrics.confusion_matrix()?

@rnett
Contributor

rnett commented Feb 16, 2021

> The core method for confusion metrics is MetricsHelper.updateConfusionMatrixVariables(), which handles all four possibilities: true and false positives, and true and false negatives.

Yeah, I just want a way to do it by passing a metric to get all 4, rather than relying on internal code.

> As far as a ConfusionMatrix metric, how would that be defined? sklearn.metrics.confusion_matrix()?

Sklearn's would work. I was thinking of PyTorch-Lightning's, but they do the same thing (Lightning bases theirs on https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix). Obviously it doesn't work with multidimensional outputs, but it's a helpful visual way to see which classes are causing errors in classification. And generating one after evaluation from the existing metrics is hard (if not impossible) since you lose the actual predictions.
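For context, a confusion matrix here means an N-by-N table of counts indexed by (true class, predicted class). A minimal plain-Java sketch of the accumulation such a metric would perform (a hypothetical illustration, not code from this PR):

    // Illustrative sketch only (not part of this PR): accumulate a confusion matrix
    // from integer class ids, producing the same counts as tf.math.confusion_matrix.
    static int[][] confusionMatrix(int[] labels, int[] predictions, int numClasses) {
      int[][] matrix = new int[numClasses][numClasses];
      for (int i = 0; i < labels.length; i++) {
        // rows index the true class, columns index the predicted class
        matrix[labels[i]][predictions[i]]++;
      }
      return matrix;
    }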

@JimClarke5
Contributor Author

JimClarke5 commented Feb 16, 2021

I see that TF has tf.math.confusion_matrix().

TF Addons has tfa.metrics.MultiLabelConfusionMatrix.

@rnett
Contributor

rnett commented Feb 17, 2021

tfa.metrics.MultiLabelConfusionMatrix seems like a misnomer; it's not a confusion matrix, just the true/false positives/negatives. The docs aren't super clear though, so I could be reading it wrong. tf.math.confusion_matrix() is exactly what I mean.

@deansher deansher left a comment (Contributor)

This review is very much a work in progress. I'm still getting my brain around AUC and I haven't looked at the other metrics yet at all.

* <p>Usage: <br>
*
* <pre>
* AUC m = new getTF().keras.metrics.AUC( getTF(), 3);
Contributor

broken by global replace

Contributor Author

OK, I have fixed it.

Contributor

Still need to drop tf.keras.metrics.

Contributor Author

Done

Contributor

👍

* @param tf The TensorFlow Ops
* @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}.
* @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
* must be &gt; 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS}
Contributor

DEFAULT_NUM_THRESHOLDS is not yet implemented.

Contributor Author

OK, added a check for numThresholds being null; if it is null, numThresholds is set to DEFAULT_NUM_THRESHOLDS.

@JimClarke5 JimClarke5 Mar 1, 2021 (Contributor Author)

I just updated this and changed the ctor parameter for numThresholds to int, so a value is required. Fixed the javadoc comment to remove the reference to the default value. So if someone wants to use DEFAULT_NUM_THRESHOLDS, they will need to pass that value to the ctor, or use one of the ctors that defaults to DEFAULT_NUM_THRESHOLDS.
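As a sketch of the overloading pattern described here (parameter lists simplified and names illustrative, not the exact signatures in this PR):

    // Illustrative sketch of the ctor overloads described above; not the actual AUC signatures.
    class AucCtorSketch {
      static final int DEFAULT_NUM_THRESHOLDS = 200; // assumed default, for illustration only
      final int numThresholds;

      // Convenience ctor: callers who want the default simply omit numThresholds.
      AucCtorSketch(long seed) {
        this(DEFAULT_NUM_THRESHOLDS, seed);
      }

      // Full ctor: numThresholds is a required int, so no null handling is needed.
      AucCtorSketch(int numThresholds, long seed) {
        if (numThresholds <= 1) {
          throw new IllegalArgumentException("numThresholds must be > 1.");
        }
        this.numThresholds = numThresholds;
      }
    }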

Contributor

👍 done

long seed,
Class<T> type) {
super(tf, name == null ? DEFAULT_NAME : name, seed);
this.truePositivesName = this.getVariableName(TRUE_POSITIVES);
Contributor

The extensive use of this seems odd to my eye.

Contributor Author

Old habits are hard to break; I will remove them where they are not necessary.

Contributor

👍 done


if (thresholds != null) { // ignore numThresholds
for (float t : thresholds)
if (t < 0.0f || t > 1.0f)
Contributor

I think our convention is to always use braces?

Contributor Author

OK

Contributor

👍 done

this.numThresholds = thresholds.length + 2;
Arrays.sort(thresholds);
} else {
if (numThresholds <= 1) throw new IllegalArgumentException("numThresholds must be > 1.");
Contributor

braces?

Contributor Author

OK

Contributor

👍 done

}
if (this.isMultiLabel() || this.getLabelWeights() != null) {
List<SymbolicShape<? extends TNumber>> symbols = new ArrayList<>();
symbols.add(new SymbolicShape<>(lLabels, "N", "L"));
Contributor

This constraint doesn't entirely match the subsequent logic: we later apply squeezeOrExpandDimensions to labels and predictions, presumably with the goal of allowing some dimensional relationships other than the one asserted here.

Contributor Author

Symbolic shapes were a challenge to figure out. What it boils down to is a set of rules that are later verified against the actual shapes at run time. If the condition is multiLabel and has labelWeights, the shape of the labels must be Number of Examples ("N") by Number of Labels ("L").
Each variable then must be Number of Thresholds ("T") by Number of Labels ("L"). The shape of labelWeights, if not null, must be Number of Labels ("L"). assertShapes() adds a collection of assertThat operations as control Ops to do the run-time checking, substituting the run-time values for "L", "T" and "N" to make sure all the Operands are consistent with these rules.

The rules get applied to the labels, the variables, and labelWeights.

Later on in updateConfusionMatrixVariables(), the rules would be asserted through control dependencies. But updateConfusionMatrixVariables() knows nothing of that, so I am not sure how to avoid handling the squeeze/compress.

I think more code notes would be helpful, but TF Python is not very verbose, so I am just trying to mimic their code as best I can. I can only assume they had a reason for writing the code the way they did. Also, there is a necessary translation from Python-ese to Java that is not always straightforward. With that said, the more we can figure out what they intended to do, the better.
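A rough plain-Java sketch of the rule-checking idea described above (purely illustrative; the framework's actual SymbolicShape/assertShapes classes differ): each symbol such as "N", "L" or "T" must bind to the same concrete dimension size everywhere it appears.

    import java.util.HashMap;
    import java.util.Map;

    // Illustrative sketch: verify that every shape symbol resolves to one consistent size.
    class SymbolicShapeCheckSketch {
      // actualShapes: operand name -> its runtime dimension sizes
      // ruleShapes:   operand name -> the symbols declared for each dimension, e.g. {"N", "L"}
      static void check(Map<String, long[]> actualShapes, Map<String, String[]> ruleShapes) {
        Map<String, Long> bindings = new HashMap<>();
        for (Map.Entry<String, String[]> rule : ruleShapes.entrySet()) {
          long[] actual = actualShapes.get(rule.getKey());
          String[] symbols = rule.getValue();
          if (actual.length != symbols.length) {
            throw new IllegalArgumentException("Rank mismatch for operand " + rule.getKey());
          }
          for (int d = 0; d < symbols.length; d++) {
            Long bound = bindings.putIfAbsent(symbols[d], actual[d]);
            if (bound != null && bound != actual[d]) {
              throw new IllegalArgumentException(
                  "Symbol " + symbols[d] + " is inconsistent on operand " + rule.getKey());
            }
          }
        }
      }
    }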

Contributor Author

I have changed updateConfusionMatrixVariables thresholds argument from float[] to Operand<TFloat32>.

Contributor

👍 done

* topK is set)
* @param topK Optional, indicates that the positive labels should be limited to the top k
* predictions, may be null.
* @param classId Optional, limits the prediction and labels to the specified class
Contributor

It would be helpful to document that classId is an index in the second dimension of prediction and labels. (And would it be better named classIndex?)

Contributor Author

classId is not an index. It is an arbitrary integer representing one class of a set of classes in classification problems.

Contributor

We later use it as an index:

if (classId != null) {
      lLabels = tf.squeeze(tf.gather(lLabels, tf.constant(new int[] {classId}), tf.constant(1)));

Contributor Author

Actually, classId is a column index into the multi-dim array, so this code is wrong; I will repair it.

y_true = y_true[..., class_id]
y_pred = y_pred[..., class_id]

This translates into a column slice for the column index matching class_id.
Perhaps classIndex would be a better name.
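For readers following along, the Python slice y_pred[..., class_id] keeps only the column belonging to that one class. A tiny plain-Java illustration for the rank-2 case (hypothetical helper, not code from this PR):

    // Illustrative only: select the classIndex column from a [numExamples, numClasses] array,
    // mirroring y_pred[..., class_id] for rank-2 input.
    static float[] selectClassColumn(float[][] predictions, int classIndex) {
      float[] column = new float[predictions.length];
      for (int row = 0; row < predictions.length; row++) {
        column[row] = predictions[row][classIndex];
      }
      return column;
    }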

Contributor Author

I have changed it to classIndex, but in doing searches on machine learning classification, I found that classId (class_id) is the preferred term for this.

}

if (classId != null) {
lLabels = tf.squeeze(tf.gather(lLabels, tf.constant(new int[] {classId}), tf.constant(1)));
Contributor

Doesn't this gather give us exactly the shape we need? Couldn't we omit both this squeeze and the subsequent expandDims?

Contributor Author

I have removed the squeeze and now just do:
tLabels = tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(1));

Contributor

👍 done

return initializers;
}

/** {@inheritDoc} */
Contributor

Given that (if I'm understanding correctly) AUC doesn't support passing batched data into a single call to updateStateList, it would be helpful to give this method its own javadoc that explains the permitted dimensional relationships.

Contributor Author

There is nothing in the AUC or metrics_utils.update_confusion_matrix_variables() Python source that does special handling of the batch dimension. I don't see whether it is required or prohibited.

Contributor

👍 done

tf.select(
tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)),
tf.constant(1),
tf.reduceProd(
Contributor

This code does extra work to accommodate prediction/label rank > 2, but it's not clear to me that our earlier code supports such a situation:

  • In the multi-label case, we've already asserted that our labels are rank 2.
  • Our handling of classId seems semantically odd if our predictions and labels have rank > 2. In that case, classId limits consideration to a slice of predictions/labels for each example rather than (as its name suggests) a single prediction/label for each example.

Contributor Author

The only restriction is on multiLabel: rank must be 2.

Also, why couldn't class features be represented with multi-dimensional input? Class (classID) is just a vertical slice of the feature set, whether it be a 1D or an nD slice.

Contributor

👍 Although I heard you that classId is the normal name for this, calling it classIndex does make me feel a ton better about having it define an nD slice.

Contributor

👍 done


Operand<T> predictionsExtraDim;
Operand<TBool> labelsExtraDim;
if (multiLabel) {
Contributor

It would be really helpful to document the shapes and semantics of the ...ExtraDim variables. By the time we get here, we started with prediction and label vectors that had both flexible and case-dependent shapes, and we've reshaped them multiple times.

Contributor Author

The ExtraDims are added so the operands of the tile operations later on are compatible. I will add a comment.

Contributor

👍 done

Operand<TInt32> numThresholds;
Operand<TBool> oneThresh;
if (multiLabel) {
numThresholds = tf.shape.size(lLabels, tf.constant(0));
Contributor

This conditional definition of numThresholds is different from the Python implementation, which simply reads

  num_thresholds = thresholds.shape[0]

Operand<TBool> oneThresh;
if (multiLabel) {
numThresholds = tf.shape.size(lLabels, tf.constant(0));
oneThresh = tf.math.equal(tf.constant(1), tf.constant(thresholds.length));
Contributor

This is different from the Python implementation, which checks whether the rank of thresholds is 1:

    one_thresh = math_ops.equal(
        math_ops.cast(1, dtype=dtypes.int32),
        array_ops.rank(thresholds),
        name='one_set_of_thresholds_cond')

Contributor Author

WRT: "This conditional definition of numThresholds is different from the Python implementation, which simply reads",
In TF Python, they convert thresholds to a tensor and get the first dimension value. Thresholds may be passed as a float, float[] or Operand.

thresholds = ops.convert_to_tensor_v2_with_dispatch(
      thresholds, dtype=variable_dtype)
  num_thresholds = thresholds.shape[0]

In Java, thresholds are passed as a float[], not an Operand. Creating an Operand and accessing its shape is unnecessary and is equivalent to taking thresholds.length. Also, because thresholds is not an Operand, the rank will always be 1.

I have not yet run into a case where the thresholds are defined by a multi-dimensional Operand, but maybe we need to support that.

Nonetheless, the code, for now, should take thresholds.length instead of dimension 0 of lLabels. I will fix.

Contributor Author

I have changed the thresholds argument to Operand so now the code reads as:

    Operand<TInt32> numThresholds = tf.shape.size(thresholds, tf.constant(0));
    Operand<TBool> oneThresh;
    if (multiLabel) {
      oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds));
    } else {
      // TODO handle Ragged Tensors????
      // [y_pred,
      //    y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
      //                                                   sampleWeights)
      oneThresh = tf.constant(true);
    }

Contributor Author

I had to reshape numThresholds to Shape.scalar() so that it is compatible with the tf.stack() operations later on.
Otherwise, I would get internal TF errors when the operations were evaluated.

Operand<TInt32> numThresholds =
        tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar()));
    Operand<TInt32> threshTilesShape = tf.stack(threshTiles);
    Operand<TInt32> stackedTiles = tf.stack(dataTiles);

Contributor

👍 done

}

if (this.multiLabel) {
this.numLabels = null;
Contributor

What's the purpose of setting numLabels to null here? I don't see the Python code doing that. It doesn't matter in our code -- we won't access numLabels again. But it may trip us up later, such as if we implement a counterpart to Python's reset_states.

Contributor Author

The condition of multiLabel + labelWeights != null is handled in
AUC.result(), not in MetricsHelper.updateConfusionMatrixVariables().

However, you are right that this.labelWeights should not be changed as it is used later. I will fix the code to pass null to updateConfusionMatrixVariables if multiLabel is true.

MetricsHelper.updateConfusionMatrixVariables(
            tf,
            confusionMatrix,
            varInitializers,
            tLabels,
            tPredictions,
            tf.constant(thresholds),
            null,
            null,
            tSampleWeights,
            isMultiLabel(),
            isMultiLabel() ? null : labelWeights));

Contributor

👍 done

updateOperations.addAll(
MetricsHelper.assertShapes(getTF(), symbols, "Number of labels is not consistent."));
}
if (this.isMultiLabel()) {
Contributor

Ah, I think setting labelWeights to null is wrong anyway. In the Python code, they don't change the instance variable to None. Rather, they create a local that may be None just for the purpose of invoking metrics_utils.update_confusion_matrix_variables:

    # Only forward label_weights to update_confusion_matrix_variables when
    # multi_label is False. Otherwise the averaging of individual label AUCs is
    # handled in AUC.result
    label_weights = None if self.multi_label else self.label_weights

Added comment on Operand<TInt32> numThresholds reshape to scalar.

Added comment to ExtraDims
* @param <T> the type of weights and values
* @throws IllegalArgumentException If static checks determine `weights` has incorrect shape.
*/
@SuppressWarnings("unchecked")
Contributor

We can eliminate the need for @SuppressWarnings("unchecked") by using a different idiom for constructing empty lists:

            .withControlDependencies(Collections.emptyList())
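For background on this suggestion: Collections.EMPTY_LIST is a raw List, so returning or passing it where a typed list is expected triggers an unchecked warning, while Collections.emptyList() infers the element type and compiles cleanly. A small standalone illustration in plain Java, outside this PR's code:

    import java.util.Collections;
    import java.util.List;

    class EmptyListIdiomSketch {
      @SuppressWarnings("unchecked")
      static List<String> rawIdiom() {
        return Collections.EMPTY_LIST;   // raw type: requires the unchecked suppression
      }

      static List<String> typedIdiom() {
        return Collections.emptyList();  // element type is inferred: no warning needed
      }
    }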

Contributor Author

OK, changed to Collections.emptyList() and removed @SuppressWarnings("unchecked").

Ops tf = getTF();
Operand<T> result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives));
return thresholds.length == 1
? tf.slice(
@deansher deansher Apr 10, 2021 (Contributor)

I don't think this accurately mirrors the Python. Also, I think the fact that our tests don't catch the problem shows a weakness in our test infrastructure.

Here's the Python code:

    return result[0] if len(self.thresholds) == 1 else result

In the case of a single threshold, it returns a scalar. But I believe this Java code returns shape (1).

Here's how we verify a return value in PrecisionTest:

      double expectedPrecision = weightedTP / weightedPositives;
      session.evaluate(expectedPrecision, precision);

This is intended to verify that precision is a scalar with value expectedPrecision. But here's the verification code in org.tensorflow.framework.metrics.PrecisionTest#testWeighted:

      try (TFloat32 result =
          (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) {
        result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon));
      }

It doesn't verify that result is a scalar. Rather, it verifies that every scalar in result has the expected value. (In fact, I think if result were a zero-length vector, the test would pass?)
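One possible way to tighten such a test (an illustrative sketch layered on the snippet above; it relies on the NdArray shape() accessor and is not what the current test does) would be to assert the rank before comparing the value:

    try (TFloat32 result =
        (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) {
      // First assert that the result really is a scalar (rank 0), then compare its value.
      assertEquals(0, result.shape().numDimensions());
      assertEquals((float) expected, result.getFloat(), epsilon);
    }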

Contributor Author

You are right, Python returns a scalar; I have changed it to:

return  thresholds.length == 1
        ? tf.reshape(tf.slice(
            result,
            tf.expandDims(tf.constant(0), tf.constant(0)),
            tf.expandDims(tf.constant(1), tf.constant(0))),
            tf.constant(Shape.scalar()))
        : result;

*
* @param tf the TensorFlow Ops
* @param name the name of the metric instance, if null then {@link Class#getSimpleName()} is used
* @param value A scalar value in range `[0, 1]`
Contributor

On first reading, value is inscrutable. Perhaps extend its documentation?

Contributor

Now that I've finished reading this class and thinking about it: perhaps value doesn't belong in this superclass? It has no functionality here, which I think makes it tough to name and explain.

Contributor Author

You are right; I have removed value from the abstract class.

* @return a List of Operations to update the metric state.
*/
@Override
@SuppressWarnings("unchecked")
Contributor

Eliminate this by using Collections.emptyMap() instead of Collections.EMPTY_MAP in the body.


/** {@inheritDoc} */
@Override
public Op resetStates() {
Contributor

Does this pattern work as intended?

  • We create a single initializer Op.
  • We return this same Op instance from every call to resetStates.
  • Presumably, we expect our caller to form a control dependency on this Op?

But the Op will only be executed once, regardless of how many times callers form control dependencies on it, right?

Contributor Author

The initializer is an Assign Operation, and when run, will assign zeroes to the accumulator, which is the stored value.
So each time resetStates is run, the accumulator will be reset to zeroes.

This Op could be used with a control dependency on accumulator, but I elected to let the caller do that if they want. The issue is that this Op should not be permanently attached to accumulator: if you are doing, for example, multiple training steps, you do not want to reset it to zero on each step, but rather add to it. However, the caller may want to reset the value to zeros, perhaps between epochs.

In the test cases I call
session.run(instance.getInitializer());
which could have equally been
session.run(instance.resetStates());
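To make the intended usage pattern concrete, here is a hedged sketch in the same session-based style as the tests in this PR (instance, session, yTrue, yPred, numSteps, and expectedResult are placeholders):

    // Hedged usage sketch, not a literal test from this PR.
    session.run(instance.resetStates());                  // zero the accumulators, e.g. at the start of an epoch
    Op update = instance.updateState(yTrue, yPred, null); // build the update op once
    for (int step = 0; step < numSteps; step++) {
      session.run(update);                                // each run adds this batch's statistics
    }
    session.evaluate(expectedResult, instance.result());  // metric accumulated over all runs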

MetricsHelper.updateConfusionMatrixVariables(
tf,
confusionMatrix,
varInitializers,
Contributor

What's the purpose of arranging for updateConfusionMatrixVariables to form control dependencies on our variable initializers? Doesn't it inherently depend on our variable-creation operations, which also initialize the variables to zero?

Perhaps the intent is for updateConfusionMatrixVariables to depend on initializers corresponding to the most recent call to resetStates? But see my comment on that method for doubts on whether that would work.

Contributor Author

The initializers only get attached to the confusion variables when the AUC updateState is first run. After that, the CM variables are not reinitialized until resetStates is called. This is done this way because the CM variables cannot be created until build is called with the input shapes on the first invocation of updateState. This is controlled by the boolean initialized member.
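A minimal sketch of the lazy-build pattern described here (field and helper names are hypothetical; the real AUC class differs):

    // Illustrative sketch of deferring variable creation to the first updateState call.
    private boolean initialized = false;

    private void maybeBuild(Shape inputShape, List<Op> updateOperations) {
      if (!initialized) {
        // The CM variables cannot be shaped until the first update supplies the input shape,
        // so create them here and run their initializers once via control dependencies.
        buildVariables(inputShape);                    // hypothetical helper creating the variables
        updateOperations.addAll(variableInitializers); // hypothetical list of the Assign ops
        initialized = true;
      }
      // later calls skip this block; resetStates() re-runs the initializers explicitly
    }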


/** {@inheritDoc} */
@Override
public Op resetStates() {
Contributor

Does this method work as intended? It creates a new no-op operation, but that operation has dependencies on the same initializer operations we created on construction. Won't those initialization operations be executed exactly once (assuming our caller depends on this no-op operation), regardless of subsequent calls to this resetStates method?

Contributor Author

If this Op is run through session.run(), the CM variables will be reset to zero. See the test case testCummulative() that I just created.

     // test reset
      session.run(instance.resetStates());

float[] expectedThresholds = new float[] {-1e-7f, 0.5f, 1 + 1e-7f};
assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon);

instance.resetStates();
Contributor

What do we expect this method to do? It creates a no-op operation that depends on our set of pre-constructed initializer operations. What effect will that have? What effect do we want from it?

Contributor Author

It does nothing at this point; I just removed it.

try (TestSession session = TestSession.createTestSession(tfMode)) {
Ops tf = session.getTF();
Accuracy<TFloat32> instance = new Accuracy<>(tf, 1001L, TFloat32.class);
session.run(instance.resetStates());
Contributor

What effect do we want from this call to session.run? Does this in some way mirror how we expect Accuracy to be used in production?

Contributor Author

This sets the variables to their initial state. The variables are total, and possibly count, in Reduce. I did not create these with an initializer, so tf.init() won't work; you have to run resetStates() to initialize the variables before your first attempt to add to them, thus the call to session.run(instance.resetStates()). The idea here is that the initializers may differ in subclasses of Reduce.

I did it this way in case there are multiple metrics that get created at various times, as it seems tf.init() is akin to the big bang. What I don't know is whether tf.init() clears any registered initializers or not. I guess I need to test this.

Contributor Author

OK, I just tested tf.init() and it does NOT clear the initializers after it is run. So, it is a big bang kind of thing.


// new instance same graph
instance = new Accuracy<>(tf, 1001L, TFloat32.class);
session.run(instance.resetStates());
Contributor

What are we trying to test with this sequence of four calls to Session.run? Are we confident that the semantics of Session.run will precisely emulate some particular production usage? Or can we write this code more like production code would be written?

In particular, I'm interested in the effect of our second combination of Session.run with Accuracy.resetStates below:

        // reset variables
        session.run(instance.resetStates());

Does that have exactly the same semantics as forming a control dependency on the Op returned by resetStates, or does it force a fresh execution of that Op even if it was previously executed?

Contributor Author

A metric with the same name will pick up the variables created by a previous metric with that same name.
It is intentional.
This test verifies that this functionality works.

For example, in Reduce

total = getTF().withName(totalName).variable(Shape.scalar(), resultType);

Returns the variable with the same name, totalName, from the Graph if it already exists. Otherwise, it creates a new Variable.

Op update = instance.updateState(yTrue, yPred, null);

for (int i = 0; i < 10; i++) {
session.run(update);
Contributor

What effect do we expect from these repeated calls to Session.run? Does it exactly mirror a use of our API that we expect in production? For example, does every iteration force a fresh execution of the Op returned by updateState, or does it merely retrieve a value cached from the first execution? Is it the same as, or different from, a caller in production repeatedly forming control dependencies on the Op returned by updateState?

Contributor Author

It re-executes the updateState operation. This particular test tests that the results do not change on each invocation.

In some other class tests, the results do change, especially if the updateState operation is doing something like random number generation. Sometimes this kind of behavior is buried deep in the TF raw Operations.

Operand<TFloat32> result = instance.result();

for (int i = 0; i < 10; i++) {
session.evaluate(result, instance.result());
Contributor

Similarly, what effect do we expect from these repeated calls to session.evaluate? What would be the equivalent in production code? Are we confident that session.evaluate accurately mirrors that production equivalent? Is there a way for us to write our code more like production code?

Contributor Author

Again, this is testing that the first fetch of result does not change after subsequent result fetches. Different implementations of the result logic might produce different results when result is evaluated, due to hidden behavior (e.g. randomness) in the TF Operations.

add logic to accept a 1 item weightsShapeStatic.
change is_scalar to isScalar
Fix logic in hasValidDims to match Python implementation.
add checks for sampleWeights rank matching labels rank
add checks on labels and predictions to make sure they have the same number of elements.
…ase_2

# Conflicts:
#	tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java
#	tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java
#	tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
#	tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java
@JimClarke5
Contributor Author

In TF Python, tf.math.confusion_matrix() produces the same result as sklearn.metrics.confusion_matrix. However, in Java this operation is not generated; in TF Python it is implemented in tensorflow/python/ops/confusion_matrix.py. We would have to port this to Java.

@JimClarke5
Contributor Author

FYI, I have implemented math.confusionMatrix in #255, FrameworkOps, and verified it against TF Python.

@karllessard
Collaborator

Hi @deansher and @JimClarke5, just checking here whether this PR is ready to be merged, after collecting 320 of your comments :)

If not, do you think it would be worth it to set up a chat room or video call to discuss it?

@JimClarke5
Contributor Author

I am ready when @deansher is. The next PR, on layers, won't be any easier with respect to review.

@deansher deansher left a comment (Contributor)

I agree! @JimClarke5 worked long and hard on this code before PR'ing it. Then he and I worked long and hard on it in PR. It's good code. Time for it to live in master.


Operand<T> trueSlice = tf.slice(this.truePositives, minIndex, tf.constant(new int[] {1}));
Operand<T> falseSlice = tf.slice(this.falsePositives, minIndex, tf.constant(new int[] {1}));
return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice));
Contributor

Do we really want to return a single-element vector containing the precision, rather than a scalar?

/**
* Accumulates true positive and false negative statistics.
*
* @param labels the labels The ground truth values, with the same dimensions as predictions. Will
Contributor

opens with a wording glitch

*/
@Override
@SuppressWarnings("unchecked")
public List<Op> updateStateList(
Contributor

This method is almost identical to the implementation in Precision. Even if just for these two instances, perhaps that logic is long enough to be worth capturing in a superclass?

Operand<T> result =
tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives));
return this.thresholds.length == 1
? tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1]))
Contributor

I don't think this slice operation returns a scalar as intended.

@deansher
Contributor

Repeating this at the top level, where it's easier to see (and cc @karllessard): I agree! @JimClarke5 worked long and hard on this code before PR'ing it. Then he and I worked long and hard on it in PR. It's good code. Time for it to live in master.

@karllessard
Collaborator

OK, great, sounds good! Good job guys, really! I'll let you figure out how you want to avoid losing track of any remaining open comments/suggestions.

@karllessard karllessard merged commit 35b73ce into tensorflow:master Apr 22, 2021
@JimClarke5 JimClarke5 deleted the metrics_phase_2 branch May 2, 2021 15:09