Metrics phase 2 #222
Conversation
Sync with master tensorflow on upstream
Merge main branch to local branch
Update after losses merge
Fix Javadoc errors (tensorflow#152)
pull type def
Metrics Phase 1 (tensorflow#180)
…when I build it reverses these 2 from master's version.
…er for predictions instead of <T>.
Pull latest tensorflow master
@karllessard There seems to be something wrong with the build environment. This is the second PR that created this error on the quick build.
@JimClarke5 I'm seeing that as well when building, locally too. Your list is missing
Would it be possible to add a
The core method for confusion metrics is
Yeah, I just want a way to do it by passing a metric to get all 4, rather than relying on internal code.
Sklearn's would work; I was thinking of PyTorch-Lightning's, but they do the same thing (Lightning bases theirs on https://scikit-learn.org/stable/modules/model_evaluation.html#confusion-matrix). Obviously it doesn't work with multidimensional outputs, but it's a helpful visual way to see which classes are causing errors in classification. And generating one after evaluation from the existing metrics is hard (if not impossible) since you lose the actual predictions.
I see that TF Addons has
Merge with latest
This review is very much work in progress. I'm still getting my brain around AUC
and I haven't looked at the other metrics yet at all.
* <p>Usage: <br>
*
* <pre>
* AUC m = new getTF().keras.metrics.AUC( getTF(), 3);
broken by global replace
OK, I have fixed it.
Still need to drop tf.keras.metrics.
Done
👍
* @param tf The TensorFlow Ops
* @param name the name of the metric, if name is null then use {@link #DEFAULT_NAME}.
* @param numThresholds the number of thresholds to use when discretizing the roc curve. Values
*     must be > 1. if null, the default is {@link #DEFAULT_NUM_THRESHOLDS}
DEFAULT_NUM_THRESHOLDS is not yet implemented.
OK, added a check for numThresholds being null; if null, numThresholds is set to DEFAULT_NUM_THRESHOLDS.
I just updated this and changed the ctor numThresholds parameter to int, so a value is required. Fixed the javadoc comment to remove the reference to the default value. So if someone wants to use DEFAULT_NUM_THRESHOLDS, the ctor will need to pass that value, or use one of the ctors that defaults to DEFAULT_NUM_THRESHOLDS.
👍 done
long seed,
Class<T> type) {
super(tf, name == null ? DEFAULT_NAME : name, seed);
this.truePositivesName = this.getVariableName(TRUE_POSITIVES);
The extensive use of this seems odd to my eye.
Old habits are hard to break, I will remove them where they are not necessary.
👍 done
if (thresholds != null) { // ignore numThresholds
for (float t : thresholds)
if (t < 0.0f || t > 1.0f)
I think our convention is to always use braces?
OK
👍 done
this.numThresholds = thresholds.length + 2;
Arrays.sort(thresholds);
} else {
if (numThresholds <= 1) throw new IllegalArgumentException("numThresholds must be > 1.");
braces?
OK
👍 done
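As a side note, the fully-braced validation discussed above can be sketched in plain Java. This is a hedged illustration only: the class name, method shape, and exact exception messages are assumptions, not the final PR code; the two-threshold padding (`thresholds.length + 2`) mirrors the snippet under review.

```java
import java.util.Arrays;

public class ThresholdCheck {
    /**
     * Validates the threshold arguments and returns the effective number of
     * thresholds. Explicit thresholds take precedence over numThresholds,
     * and two implicit endpoint thresholds are added to their count.
     */
    static int validate(float[] thresholds, int numThresholds) {
        if (thresholds != null) { // ignore numThresholds
            for (float t : thresholds) {
                if (t < 0.0f || t > 1.0f) {
                    throw new IllegalArgumentException("Threshold values must be in [0, 1], inclusive.");
                }
            }
            Arrays.sort(thresholds);
            return thresholds.length + 2;
        } else {
            if (numThresholds <= 1) {
                throw new IllegalArgumentException("numThresholds must be > 1.");
            }
            return numThresholds;
        }
    }
}
```

With braces on every block, adding a statement to either branch later cannot silently fall outside the conditional.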
}
if (this.isMultiLabel() || this.getLabelWeights() != null) {
List<SymbolicShape<? extends TNumber>> symbols = new ArrayList<>();
symbols.add(new SymbolicShape<>(lLabels, "N", "L"));
This constraint doesn't entirely match the subsequent logic: we later apply squeezeOrExpandDimensions to labels and predictions, presumably with the goal of allowing some dimensional relationships other than the one asserted here.
Symbolic shapes were a challenge to figure out. What it boils down to is a set of rules used later to verify that the runtime shapes are consistent. If the condition is multiLabel and has labelWeights, the shape of the labels must be Number of Examples ("N") by Number of Labels ("L"). Each variable then must be Number of Thresholds ("T") by Number of Labels ("L"). The shape of labelWeights, if not null, must be Number of Labels ("L"). assertShapes() adds a collection of assertThat operations as control Ops to do the runtime checking, substituting the runtime values for "L", "T" and "N" to make sure all the Operands are consistent with these rules. The rules get applied to the labels, the variables, and labelWeights.
Later on, in updateConfusionMatrixVariables(), the rules would be asserted through control dependencies. But updateConfusionMatrixVariables() knows nothing of that, so I am not sure how to avoid handling the squeeze/expand.
I think more code comments would be helpful, but the TF Python source is not very verbose, so I am just trying to mimic their code as best I can. I can only assume they had a reason for writing the code the way they did. Also, the necessary translation from Python to Java is not always straightforward. With that said, the more we can figure out what they intended to do, the better.
I have changed the updateConfusionMatrixVariables thresholds argument from float[] to Operand<TFloat32>.
👍 done
*     topK is set)
* @param topK Optional, indicates that the positive labels should be limited to the top k
*     predictions, may be null.
* @param classId Optional, limits the prediction and labels to the specified class
Would be helpful to document that classId is an index in the second dimension of prediction and labels. (And would it be better named classIndex?)
classId is not an index. It is an arbitrary integer representing one class out of a set of classes in classification problems.
We later use it as an index:
if (classId != null) {
lLabels = tf.squeeze(tf.gather(lLabels, tf.constant(new int[] {classId}), tf.constant(1)));
Actually, classId is a column index into the multi-dim array, so this code is wrong; I will repair it.
y_true = y_true[..., class_id]
y_pred = y_pred[..., class_id]
This translates into a column slice for the column index matching class_id. Perhaps classIndex would be a better name.
I have changed it to classIndex, but in doing searches on machine-learning classification, I found that classId (class_id) is the preferred term for this.
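For illustration only, here is a TF-free, plain-Java picture of the `y_true[..., class_id]` slice that the `tf.gather(..., axis 1)` call above performs on a rank-2 input: `classIndex` is a column position, not a class label value. The class and method names are hypothetical.

```java
public class ClassSlice {
    /**
     * Returns column classIndex of a rank-2 labels array, i.e. the
     * equivalent of labels[..., classIndex]: one value per row (example).
     */
    static float[] sliceColumn(float[][] labels, int classIndex) {
        float[] column = new float[labels.length];
        for (int row = 0; row < labels.length; row++) {
            column[row] = labels[row][classIndex];
        }
        return column;
    }
}
```

The result already has the per-example shape, which is why the squeeze/expandDims pair around the gather came under question below.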
}

if (classId != null) {
lLabels = tf.squeeze(tf.gather(lLabels, tf.constant(new int[] {classId}), tf.constant(1)));
Doesn't this gather give us exactly the shape we need? Couldn't we omit both this squeeze and the subsequent expandDims?
I have removed the squeeze and now just do:
tLabels = tf.gather(tLabels, tf.constant(new int[] {classIndex}), tf.constant(1));
👍 done
return initializers;
}

/** {@inheritDoc} */
Given that (if I'm understanding correctly) AUC doesn't support passing batched data into a single call to updateStateList, it would be helpful to give this method its own javadoc that explains the permitted dimensional relationships.
There is nothing in the AUC or metrics_utils.update_confusion_matrix_variables() Python source that does special handling of the batch dimension. I don't see whether it is required or prohibited.
👍 done
tf.select(
tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)),
tf.constant(1),
tf.reduceProd(
This code does extra work to accommodate prediction/label rank > 2, but it's not clear to me that our earlier code supports such a situation:
- In the multi-label case, we've already asserted that our labels are rank 2.
- Our handling of classId seems semantically odd if our predictions and labels have rank > 2. In that case, classId limits consideration to a slice of predictions/labels for each example rather than (as its name suggests) a single prediction/label for each example.
The only restriction is on multiLabel: rank must be 2.
Also, why couldn't class features be represented with multi-dimensional input? A class (classId) is just a vertical slice of the feature set, whether it is a 1D or an nD slice.
👍 Although I hear you that classId is the normal name for this, calling it classIndex does make me feel a ton better about having it define an nD slice.
👍 done
Operand<T> predictionsExtraDim;
Operand<TBool> labelsExtraDim;
if (multiLabel) {
It would be really helpful to document the shapes and semantics of the ...ExtraDim variables. By the time we get here, we started with prediction and label vectors that had both flexible and case-dependent shapes, and we've reshaped them multiple times.
The ExtraDims are added so the operands of the tile operations later on are compatible. I will add a comment.
👍 done
Operand<TInt32> numThresholds;
Operand<TBool> oneThresh;
if (multiLabel) {
numThresholds = tf.shape.size(lLabels, tf.constant(0));
This conditional definition of numThresholds is different from the Python implementation, which simply reads:
num_thresholds = thresholds.shape[0]
Operand<TBool> oneThresh;
if (multiLabel) {
numThresholds = tf.shape.size(lLabels, tf.constant(0));
oneThresh = tf.math.equal(tf.constant(1), tf.constant(thresholds.length));
This is different from the Python implementation, which checks whether the rank of thresholds is 1:
one_thresh = math_ops.equal(
math_ops.cast(1, dtype=dtypes.int32),
array_ops.rank(thresholds),
name='one_set_of_thresholds_cond')
WRT "This conditional definition of numThresholds is different from the Python implementation, which simply reads":
In TF Python, they convert thresholds to a tensor and take the value of its first dimension. thresholds may be passed as a float, float[], or Operand.
thresholds = ops.convert_to_tensor_v2_with_dispatch(
    thresholds, dtype=variable_dtype)
num_thresholds = thresholds.shape[0]
In Java, thresholds are passed as a float[], not an Operand. Creating an Operand and accessing its shape is unnecessary and is equivalent to taking thresholds.length. Also, because thresholds is not an Operand, the rank will always be 1.
I have not yet run into a case where the thresholds are defined by a multi-dimensional Operand, but maybe we need to support that.
Nonetheless, the code, for now, should take thresholds.length instead of dimension 0 of lLabels. I will fix.
I have changed the thresholds argument to Operand so now the code reads as:
Operand<TInt32> numThresholds = tf.shape.size(thresholds, tf.constant(0));
Operand<TBool> oneThresh;
if (multiLabel) {
oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds));
} else {
// TODO handle Ragged Tensors????
// [y_pred,
// y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true],
// sampleWeights)
oneThresh = tf.constant(true);
}
I had to reshape numThresholds to Shape.scalar() so that it is compatible with the tf.stack() operations later on. Otherwise, I would get internal TF errors when the operations were evaluated.
Operand<TInt32> numThresholds =
tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar()));
Operand<TInt32> threshTilesShape = tf.stack(threshTiles);
Operand<TInt32> stackedTiles = tf.stack(dataTiles);
👍 done
}

if (this.multiLabel) {
this.numLabels = null;
What's the purpose of setting numLabels to null here? I don't see the Python code doing that. It doesn't matter in our code -- we won't access numLabels again. But it may trip us up later, such as if we implement a counterpart to Python's reset_states.
The condition of multiLabel + labelWeights != null is handled in AUC.result(), not in MetricsHelper.updateConfusionMatrixVariables().
However, you are right that this.labelWeights should not be changed, as it is used later. I will fix the code to pass null to updateConfusionMatrixVariables if multiLabel is true.
MetricsHelper.updateConfusionMatrixVariables(
tf,
confusionMatrix,
varInitializers,
tLabels,
tPredictions,
tf.constant(thresholds),
null,
null,
tSampleWeights,
isMultiLabel(),
isMultiLabel() ? null : labelWeights));
👍 done
updateOperations.addAll(
MetricsHelper.assertShapes(getTF(), symbols, "Number of labels is not consistent."));
}
if (this.isMultiLabel()) {
Ah, I think setting labelWeights to null is wrong anyway. In the Python code, they don't change the instance variable to None. Rather, they create a local that may be None just for the purpose of invoking metrics_utils.update_confusion_matrix_variables:
# Only forward label_weights to update_confusion_matrix_variables when
# multi_label is False. Otherwise the averaging of individual label AUCs is
# handled in AUC.result
label_weights = None if self.multi_label else self.label_weights
Added comment on Operand<TInt32> numThresholds reshape to scalar. Added comment to ExtraDims
Reformat code
Change all <code>xxxxxx</code> to {@code xxxxxx}
* @param <T> the type of weights and values
* @throws IllegalArgumentException If static checks determine `weights` has incorrect shape.
*/
@SuppressWarnings("unchecked")
We can eliminate the need for @SuppressWarnings("unchecked") by using a different idiom for constructing empty lists:
.withControlDependencies(Collections.emptyList())
OK, changed to Collections.emptyList() and removed @SuppressWarnings("unchecked").
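For reference, the reason this works: Collections.emptyList() is a generic method whose return type is inferred at the call site, while the raw Collections.EMPty_LIST field requires an unchecked cast. A minimal side-by-side sketch (class and method names are hypothetical; withControlDependencies itself is not shown):

```java
import java.util.Collections;
import java.util.List;

public class EmptyListIdiom {
    /** Raw field: needs a cast and an unchecked-warning suppression. */
    static List<String> rawField() {
        @SuppressWarnings("unchecked")
        List<String> deps = (List<String>) Collections.EMPTY_LIST;
        return deps;
    }

    /** Generic method: the element type is inferred, no warning emitted. */
    static List<String> genericMethod() {
        return Collections.emptyList();
    }
}
```

Both return the same immutable shared instance; only the compile-time typing differs, which is why the suppression annotation can be dropped.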
Ops tf = getTF();
Operand<T> result = tf.math.divNoNan(truePositives, tf.math.add(truePositives, falsePositives));
return thresholds.length == 1
    ? tf.slice(
I don't think this accurately mirrors the Python. Also, I think the fact that our tests don't catch the problem shows a weakness in our test infrastructure.
Here's the Python code:
return result[0] if len(self.thresholds) == 1 else result
In the case of a single threshold, it returns a scalar. But I believe this Java code returns shape (1).
Here's how we verify a return value in PrecisionTest:
double expectedPrecision = weightedTP / weightedPositives;
session.evaluate(expectedPrecision, precision);
This is intended to verify that precision is a scalar with value expectedPrecision. But here's the verification code in org.tensorflow.framework.metrics.PrecisionTest#testWeighted:
try (TFloat32 result =
    (TFloat32) this.getGraphSession().runner().fetch(input).run().get(0)) {
  result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon));
}
It doesn't verify that result is a scalar. Rather, it verifies that every scalar in result has the expected value. (In fact, I think if result were a zero-length vector, the test would pass?)
You are right, Python returns a scalar; I have changed it to:
return thresholds.length == 1
? tf.reshape(tf.slice(
result,
tf.expandDims(tf.constant(0), tf.constant(0)),
tf.expandDims(tf.constant(1), tf.constant(0))),
tf.constant(Shape.scalar()))
: result;
*
* @param tf the TensorFlow Ops
* @param name the name of the metric instance, if null then {@link Class#getSimpleName()} is used
* @param value A scalar value in range `[0, 1]`
On first reading, value is inscrutable. Perhaps extend its documentation?
Now that I've finished reading this class and thinking about it: perhaps value doesn't belong in this superclass? It has no functionality here, which I think makes it tough to name and explain.
You are right, I have removed value from the abstract class.
* @return a List of Operations to update the metric state.
*/
@Override
@SuppressWarnings("unchecked")
Eliminate this by using Collections.emptyMap() instead of Collections.EMPTY_MAP in the body.
/** {@inheritDoc} */
@Override
public Op resetStates() {
Does this pattern work as intended?
- We create a single initializer Op.
- We return this same Op instance from every call to resetStates.
- Presumably, we expect our caller to form a control dependency on this Op?
But the Op will only be executed once, regardless of how many times callers form control dependencies on it, right?
The initializer is an Assign operation that, when run, assigns zeros to the accumulator, which is the stored value. So each time resetStates is run, the accumulator is reset to zeros.
This Op could be used with a control dependency on accumulator, but I elected to let the caller do that if wanted. The issue is that this Op should not be permanently attached to accumulator: if you are doing, for example, multiple training steps, you do not want to reset it to zero on each step, but rather add to it. However, the caller may want to reset the value to zeros, perhaps between epochs.
In the test cases I call
session.run(instance.getInitializer());
which could equally have been
session.run(instance.resetStates());
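A toy, TF-free analogue of that contract may make the intent concrete (all names here are hypothetical, not the PR's API): the reset is a reusable action detached from the update path, so invoking it re-zeroes the state every time, while updates otherwise keep accumulating.

```java
public class AccumulatorSketch {
    private float total = 0f;

    /** Analogue of the updateState Op: adds into the stored accumulator. */
    void updateState(float value) {
        total += value;
    }

    /** Analogue of running the Assign initializer: re-zeroes the state. */
    void resetStates() {
        total = 0f;
    }

    float result() {
        return total;
    }
}
```

The open question in the thread is whether, in graph mode, re-running the same Assign Op via session.run truly re-executes it each time, which is what this in-memory sketch takes for granted.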
MetricsHelper.updateConfusionMatrixVariables(
    tf,
    confusionMatrix,
    varInitializers,
What's the purpose of arranging for updateConfusionMatrixVariables to form control dependencies on our variable initializers? Doesn't it inherently depend on our variable-creation operations, which also initialize the variables to zero?
Perhaps the intent is for updateConfusionMatrixVariables to depend on initializers corresponding to the most recent call to resetStates? But see my comment on that method for doubts on whether that would work.
The initializers only get attached to the confusion-matrix variables when the AUC updateStates is first run. After that, the CM variables are not reinitialized until resetStates is called. This is done this way because the CM variables cannot be created until build is called with the input shapes on the first invocation of updateStates. This is controlled by the boolean initialized member.
/** {@inheritDoc} */
@Override
public Op resetStates() {
Does this method work as intended? It creates a new no-op operation, but that operation has dependencies on the same initializers operations we created on construction. Won't those initialization operations be executed exactly once (assuming our caller depends on this no-op operation), regardless of subsequent calls to this resetStates method?
If this Op is run through session.run(), the CM variables will be reset to zero. See the test case, testCummulative(), that I just created.
// test reset
session.run(instance.resetStates());
float[] expectedThresholds = new float[] {-1e-7f, 0.5f, 1 + 1e-7f};
assertArrayEquals(expectedThresholds, instance.getThresholds(), epsilon);

instance.resetStates();
What do we expect this method to do? It creates a no-op operation that depends on our set of pre-constructed initializer operations. What effect will that have? What effect do we want from it?
It does nothing at this point, I just removed it.
try (TestSession session = TestSession.createTestSession(tfMode)) {
Ops tf = session.getTF();
Accuracy<TFloat32> instance = new Accuracy<>(tf, 1001L, TFloat32.class);
session.run(instance.resetStates());
What effect do we want from this call to session.run? Does this in some way mirror how we expect Accuracy to be used in production?
This sets the variables to their initial state. The variables are total, and possibly count, in Reduce. I did not create these with an initializer, so tf.init() won't work; you have to run resetStates() to initialize the variables before your first attempt to add to them, thus the call to session.run(instance.resetStates()). The idea here is that the initializers may differ in subclasses of Reduce.
I did it this way in case there were multiple metrics that get created at various times, as it seems tf.init() is akin to the big bang. What I don't know is whether tf.init() clears any registered initializers or not. I guess I need to test this.
OK, I just tested tf.init() and it does NOT clear the initializers after it is run. So, it is a big-bang kind of thing.
// new instance same graph
instance = new Accuracy<>(tf, 1001L, TFloat32.class);
session.run(instance.resetStates());
What are we trying to test with this sequence of four calls to Session.run? Are we confident that the semantics of Session.run will precisely emulate some particular production usage? Or can we write this code more like production code would be written?
In particular, I'm interested in the effect of our second combination of Session.run with Accuracy.resetStates below:
// reset variables
session.run(instance.resetStates());
Does that have exactly the same semantics as forming a control dependency on the Op returned by resetStates, or does it force a fresh execution of that Op even if it was previously executed?
A metric with the same name will pick up the variables created by a previous metric with that same name. It is intentional, and this test tests that this functionality works. For example, in `Reduce`:
total = getTF().withName(totalName).variable(Shape.scalar(), resultType);
This returns the variable with the same name, `totalName`, from the Graph if it already exists. Otherwise, it creates a new `Variable`.
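The by-name variable reuse described above can be sketched in plain Java (this is an illustrative stand-in for the Graph's behavior, not the TensorFlow Java implementation): asking for a variable by name returns the existing one if it was already created, otherwise a fresh one.

```java
// Plain-Java sketch of name-based variable reuse. NOT the TensorFlow
// Java Graph API; it only mirrors the lookup-or-create behavior above.
class VariableRegistry {
    // maps a variable name to its backing storage (a one-element array
    // standing in for a scalar variable)
    private final java.util.Map<String, double[]> variables = new java.util.HashMap<>();

    /** Returns the variable registered under name, creating it if absent. */
    public double[] variable(String name) {
        return variables.computeIfAbsent(name, n -> new double[1]);
    }
}
```

A second metric constructed with the same name would call `variable("total")` and receive the same backing state, which is exactly what the test exercises.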
Op update = instance.updateState(yTrue, yPred, null);
for (int i = 0; i < 10; i++) {
  session.run(update);
What effect do we expect from these repeated calls to `Session.run`? Does it exactly mirror a use of our API that we expect in production? For example, does every iteration force a fresh execution of the `Op` returned by `updateState`, or does it merely retrieve a value cached from the first execution? Is it the same as, or different from, a caller in production repeatedly forming control dependencies on the `Op` returned by `updateState`?
It re-executes the `updateState` operation. This particular test tests that the results do not change on each invocation.
In some other class tests, the results do change, especially if the `updateState` operation is doing something like random number generation. Sometimes this kind of behavior is buried deep in the TF raw Operations.
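Why repeated `updateState` runs on the same batch leave an accuracy-style result unchanged can be shown with a minimal plain-Java accumulator (an illustrative sketch, not the framework's `Accuracy` class): each run adds the same increments to both the correct count and the seen count, so their ratio is constant.

```java
// Plain-Java sketch of an Accuracy-style accumulator. Repeated updates
// with the same batch scale numerator and denominator equally, so the
// ratio returned by result() stays the same.
class AccuracyAccumulator {
    private double correct; // matching label/prediction pairs seen so far
    private double seen;    // total pairs seen so far

    public void update(int[] labels, int[] predictions) {
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] == predictions[i]) {
                correct++;
            }
            seen++;
        }
    }

    public double result() {
        return seen == 0 ? 0.0 : correct / seen;
    }
}
```

This is the invariant the loop of ten `session.run(update)` calls checks: fresh executions do happen, but with deterministic inputs the result is stable.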
Operand<TFloat32> result = instance.result();
for (int i = 0; i < 10; i++) {
  session.evaluate(result, instance.result());
Similarly, what effect do we expect from these repeated calls to `session.evaluate`? What would be the equivalent in production code? Are we confident that `session.evaluate` accurately mirrors that production equivalent? Is there a way for us to write our code more like production code?
Again, this is testing that the first fetch of `result` does not change after subsequent `result` fetches. Different implementations of the result logic might produce different values each time `result` is evaluated, due to hidden behavior (e.g. randomness) in the TF Operations.
add logic to accept a 1 item weightsShapeStatic. change is_scalar to isScalar. Fix logic in hasValidDims to match Python implementation.
add checks for sampleWeights rank matching labels rank. add checks on labels and predictions to make sure they have the same number of elements.
…ase_2 # Conflicts: # tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java # tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Precision.java # tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java # tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java
TF Python,
FYI, I have implemented
fix sub-class CTORs.
remove SuppressWarning("unchecked");
…o the CM variables.
Hi @deansher and @JimClarke5, just checking here if this PR is ready to be merged, after collecting 320 of your comments :) If not, do you think it would be worth it to set up a chat room or video call to discuss it?
I am ready when @deansher is. The next PR on layers won't be any easier wrt review.
I agree! @JimClarke5 worked long and hard on this code before PR'ing it. Then he and I worked long and hard on it in PR. It's good code. Time for it to live in `master`.
Operand<T> trueSlice = tf.slice(this.truePositives, minIndex, tf.constant(new int[] {1}));
Operand<T> falseSlice = tf.slice(this.falsePositives, minIndex, tf.constant(new int[] {1}));
return tf.math.divNoNan(trueSlice, tf.math.add(trueSlice, falseSlice));
Do we really want to return a single-element vector containing the precision, rather than a scalar?
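The math in the snippet under review can be sketched in plain Java (illustrative names; `divNoNan` here mirrors the documented semantics of TF's `tf.math.divNoNan`, which returns 0 instead of NaN when the denominator is 0): take the counts at one threshold index (the slice) and form precision = TP / (TP + FP).

```java
// Plain-Java sketch of the per-threshold precision computation above.
// divNoNan mirrors tf.math.divNoNan: 0 when the denominator is 0.
final class PrecisionMath {
    private PrecisionMath() {}

    static double divNoNan(double x, double y) {
        return y == 0.0 ? 0.0 : x / y;
    }

    /** Precision at one threshold index, from per-threshold TP/FP counts. */
    static double precisionAt(double[] truePositives, double[] falsePositives, int minIndex) {
        double tp = truePositives[minIndex]; // the one-element "slice"
        double fp = falsePositives[minIndex];
        return divNoNan(tp, tp + fp);
    }
}
```

The divNoNan guard matters when no positives were predicted at a threshold, where a plain division would yield NaN and poison downstream reductions.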
/**
 * Accumulates true positive and false negative statistics.
 *
 * @param labels the labels The ground truth values, with the same dimensions as predictions. Will
opens with a wording glitch
*/
@Override
@SuppressWarnings("unchecked")
public List<Op> updateStateList(
This method is almost identical to the implementation in `Precision`. Even if just for these two instances, perhaps that logic is long enough to be worth capturing in a superclass?
Operand<T> result =
    tf.math.divNoNan(this.truePositives, tf.math.add(this.truePositives, this.falseNegatives));
return this.thresholds.length == 1
    ? tf.slice(result, tf.constant(new int[] {0}), tf.constant(new int[1]))
I don't think this `slice` operation returns a scalar as intended.
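The rank issue raised above has a direct plain-Java analogue (a sketch, not the TF slice implementation): a size-1 slice of a vector is still a one-element vector, and an explicit extra indexing step is needed to get a true scalar out.

```java
// Plain-Java illustration: a size-1 slice keeps rank 1, it does not
// become a scalar. Explicit indexing performs the rank reduction.
final class SliceRankDemo {
    private SliceRankDemo() {}

    /** Size-1 slice of a vector, analogous to tf.slice(..., size = {1}). */
    static double[] sliceFirst(double[] result) {
        return java.util.Arrays.copyOfRange(result, 0, 1);
    }
}
```

In the metric code, returning the one-element slice rather than a scalar would surface in the result's shape, which is presumably what the review comment is flagging.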
Repeating this at the top level, where it's easier to see (and cc @karllessard): I agree! @JimClarke5 worked long and hard on this code before PR'ing it. Then he and I worked long and hard on it in PR. It's good code. Time for it to live in master.
Ok great, sounds good! Good job guys, really! I'll let you figure out how you want to keep track of any remaining open comments/suggestions.
This is Phase 2 of Metrics.
This includes metrics that are not built on losses.
The new classes are the `xxxxAccuracy` set of classes and the Confusion Matrix set of classes: `AUC`, `FalseNegatives`, `FalsePositives`, `PrecisionAtRecall`, `RecallAtPrecision`, `Recall`, `SensitivityAtSpecificity`, and `SpecificityAtSensitivity`.
This PR is based on the current master branch, including metrics1 and the metrics/losses generic cleanup, and is not dependent on other PRs.