Skip to content

Metrics phase 2 #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 115 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
115 commits
Select commit Hold shift + click to select a range
c57a2e7
Merge pull request #3 from tensorflow/master
JimClarke5 Oct 8, 2020
09fc07e
Merge pull request #4 from tensorflow/master
JimClarke5 Oct 27, 2020
a99dcb4
Merge pull request #5 from tensorflow/master
JimClarke5 Nov 17, 2020
ba294ea
Merge pull request #6 from tensorflow/master
JimClarke5 Nov 19, 2020
04f419a
Merge pull request #7 from tensorflow/master
JimClarke5 Dec 30, 2020
02e7ebf
Merge pull request #8 from tensorflow/master
JimClarke5 Jan 29, 2021
e0c9ed8
Merge pull request #9 from tensorflow/master
JimClarke5 Feb 1, 2021
b29edfd
Simplify generic parameters across losses and metrics.
JimClarke5 Feb 1, 2021
d7f7e4c
Reformat code
JimClarke5 Feb 1, 2021
6b4149c
Change order of TrainOps and QuantiQuantizationOps. For some reason, …
JimClarke5 Feb 1, 2021
e486a90
Fix LossMetric to change abstract "call" method to use gneric paramet…
JimClarke5 Feb 3, 2021
c711532
Reformat code, fix javadoc
JimClarke5 Feb 6, 2021
1dfb7c3
Update with new generic parameters
JimClarke5 Feb 3, 2021
a3aa3c4
Reformat code, fix javadoc
JimClarke5 Feb 7, 2021
5b0374b
Merge pull request #10 from tensorflow/master
JimClarke5 Feb 11, 2021
214140a
Merge branch 'master' into metrics_phase_2
JimClarke5 Feb 16, 2021
e038bbd
Merge pull request #11 from tensorflow/master
JimClarke5 Feb 23, 2021
2c93884
Change thresholds to Operand<TFloat32>
JimClarke5 Mar 2, 2021
616ebb2
change classId to classIndex
JimClarke5 Mar 2, 2021
af69e00
fix spurious "this.".
JimClarke5 Mar 2, 2021
def3051
Merge pull request #13 from tensorflow/master
JimClarke5 Mar 3, 2021
05c3d88
Remove references to keras in javadoc.
JimClarke5 Mar 3, 2021
eb8b2f8
Merge branch 'master' of https://github.com/JimClarke5/java into metr…
JimClarke5 Mar 3, 2021
dcfaa0f
Fix javadoc
JimClarke5 Mar 3, 2021
de9bc10
Reformat code and fix labelWeights argument in call to updateConfusio…
JimClarke5 Mar 3, 2021
4520294
Reformat code add code comments and change update_xx (update_fn) to u…
JimClarke5 Mar 3, 2021
a0b7041
Added javadocs and internal docs to AUC.java and MetricsHelper.java
deansher Mar 3, 2021
4dce2cf
Added internal docs to MetricsHelper.java
deansher Mar 5, 2021
7274cf5
Improved internal docs in MetricsHelper.java
deansher Mar 6, 2021
5e8fac6
Cleanup of updateConfusionMatrixVariables with variable name changes …
JimClarke5 Mar 10, 2021
9ea129e
Reformat code
JimClarke5 Mar 10, 2021
95c5b0c
Merge branch 'metrics_phase_2' into metrics_phase_2
JimClarke5 Mar 10, 2021
86c50de
Merge pull request #12 from deansher/metrics_phase_2
JimClarke5 Mar 10, 2021
9cb4cc0
Fix JavaDoc for enumerations
JimClarke5 Mar 11, 2021
7503bbc
Fix JavaDoc to emphasize that this does not inherit from Tensor.
JimClarke5 Mar 11, 2021
19c881c
Fix 'import *'
JimClarke5 Mar 11, 2021
ebecb5e
Fix casts
JimClarke5 Mar 11, 2021
478441e
Reformat code
JimClarke5 Mar 12, 2021
e9bea47
Reformat code
JimClarke5 Mar 12, 2021
8ae78cd
Fix javadoc change >= to &ge;
JimClarke5 Mar 12, 2021
b56344f
Revised and improved internal docs in MetricsHelper.java
deansher Mar 18, 2021
b4373dc
Fix spelling in JavaDoc
JimClarke5 Mar 18, 2021
99d2610
Change assertShapes to use runtime sizes as Operands rather than use …
JimClarke5 Mar 18, 2021
e379582
Tweaked internal docs in MetricsHelper.java
deansher Mar 19, 2021
96dce5c
Replace calls to tf.slice with private method slice to clean up code.
JimClarke5 Mar 21, 2021
9f4044a
Fix Javdoc, remove spurious y_pred.
JimClarke5 Mar 21, 2021
5e50e95
remove spurious cast
JimClarke5 Mar 21, 2021
81fc9fd
correct comments for enums
JimClarke5 Mar 21, 2021
e63fc29
Merge pull request #14 from deansher/metrics_phase_2
JimClarke5 Mar 21, 2021
11748ae
Merge pull request #15 from tensorflow/master
JimClarke5 Mar 21, 2021
743f416
Fix the documentation on TP, FP, TN, and FN
JimClarke5 Mar 21, 2021
5ed35f3
Added code comments to fitlerTopK.
JimClarke5 Mar 21, 2021
929ce7e
JavaDoc fixes and code cleanup and add code comments
JimClarke5 Apr 1, 2021
1b36693
JavaDoc fixes and code cleanup and add code comments
JimClarke5 Apr 1, 2021
d301b01
Fix code in sparseTopKCategoricalAccuracy to reshape to proper dimens…
JimClarke5 Apr 1, 2021
020eb9c
Fix JavaDoc
JimClarke5 Apr 1, 2021
3ae642d
Fix JavaDoc
JimClarke5 Apr 1, 2021
8b881b4
Fixed Javadoc, mainly to add shape requirements.
JimClarke5 Apr 1, 2021
001f051
Fixed Javadoc errors.
JimClarke5 Apr 1, 2021
a9412ea
Merge pull request #16 from tensorflow/master
JimClarke5 Apr 9, 2021
088e0d8
Simplify generic parameters across losses and metrics.
JimClarke5 Feb 1, 2021
4add028
Reformat code
JimClarke5 Feb 1, 2021
c23163c
Change order of TrainOps and QuantiQuantizationOps. For some reason, …
JimClarke5 Feb 1, 2021
7bd1fcf
Fix LossMetric to change abstract "call" method to use gneric paramet…
JimClarke5 Feb 3, 2021
74a548f
Reformat code, fix javadoc
JimClarke5 Feb 6, 2021
0eab19c
Update with new generic parameters
JimClarke5 Feb 3, 2021
f258e38
Reformat code, fix javadoc
JimClarke5 Feb 7, 2021
36e0934
Change thresholds to Operand<TFloat32>
JimClarke5 Mar 2, 2021
60d513d
change classId to classIndex
JimClarke5 Mar 2, 2021
ea2e3b1
fix spurious "this.".
JimClarke5 Mar 2, 2021
ca9e395
Remove references to keras in javadoc.
JimClarke5 Mar 3, 2021
817430f
Fix javadoc
JimClarke5 Mar 3, 2021
b440c63
Reformat code and fix labelWeights argument in call to updateConfusio…
JimClarke5 Mar 3, 2021
021df65
Reformat code add code comments and change update_xx (update_fn) to u…
JimClarke5 Mar 3, 2021
4df4a80
Added javadocs and internal docs to AUC.java and MetricsHelper.java
deansher Mar 3, 2021
99df6b4
Added internal docs to MetricsHelper.java
deansher Mar 5, 2021
0710ffe
Improved internal docs in MetricsHelper.java
deansher Mar 6, 2021
7e61ba2
Cleanup of updateConfusionMatrixVariables with variable name changes …
JimClarke5 Mar 10, 2021
efd7d43
Reformat code
JimClarke5 Mar 10, 2021
5e907df
Fix JavaDoc for enumerations
JimClarke5 Mar 11, 2021
8856c9c
Fix JavaDoc to emphasize that this does not inherit from Tensor.
JimClarke5 Mar 11, 2021
b533b2e
Fix 'import *'
JimClarke5 Mar 11, 2021
21029a7
Fix casts
JimClarke5 Mar 11, 2021
e154453
Reformat code
JimClarke5 Mar 12, 2021
eb0a7e6
Reformat code
JimClarke5 Mar 12, 2021
a29f8e9
Fix javadoc change >= to &ge;
JimClarke5 Mar 12, 2021
d47c3b8
Fix spelling in JavaDoc
JimClarke5 Mar 18, 2021
7f46673
Change assertShapes to use runtime sizes as Operands rather than use …
JimClarke5 Mar 18, 2021
41fde65
Replace calls to tf.slice with private method slice to clean up code.
JimClarke5 Mar 21, 2021
3c7e3a7
Fix Javdoc, remove spurious y_pred.
JimClarke5 Mar 21, 2021
4a114be
remove spurious cast
JimClarke5 Mar 21, 2021
da72efd
correct comments for enums
JimClarke5 Mar 21, 2021
fb1ab3a
Revised and improved internal docs in MetricsHelper.java
deansher Mar 18, 2021
b2937cd
Tweaked internal docs in MetricsHelper.java
deansher Mar 19, 2021
d9c8352
Fix the documentation on TP, FP, TN, and FN
JimClarke5 Mar 21, 2021
ce07c25
Added code comments to fitlerTopK.
JimClarke5 Mar 21, 2021
e9f1a35
JavaDoc fixes and code cleanup and add code comments
JimClarke5 Apr 1, 2021
b176432
JavaDoc fixes and code cleanup and add code comments
JimClarke5 Apr 1, 2021
d087a6f
Fix code in sparseTopKCategoricalAccuracy to reshape to proper dimens…
JimClarke5 Apr 1, 2021
6aec4ff
Fix JavaDoc
JimClarke5 Apr 1, 2021
9090e31
Fix JavaDoc
JimClarke5 Apr 1, 2021
4e5906c
Fixed Javadoc, mainly to add shape requirements.
JimClarke5 Apr 1, 2021
980bb64
Fixed Javadoc errors.
JimClarke5 Apr 1, 2021
b91cabf
Use zero operand for initializing falsePositives
JimClarke5 Apr 12, 2021
56b0300
Usefix javadoc
JimClarke5 Apr 12, 2021
f1203aa
fix javadoc
JimClarke5 Apr 12, 2021
dedaede
add axis to tf.gather, and tf.squeeze results.
JimClarke5 Apr 12, 2021
fc0f9be
Merge remote-tracking branch 'origin/metrics_phase_2' into metrics_ph…
JimClarke5 Apr 12, 2021
0974038
Remove unnecessary toStrings() on Shapes.
JimClarke5 Apr 16, 2021
a3deb5c
remove value from SensititySpecificityBase,
JimClarke5 Apr 18, 2021
8cdc776
fix result to reshape slice to scalar.
JimClarke5 Apr 18, 2021
0a19b80
Change Collections.EMPTY_LIST to Collections.emtpyList(),
JimClarke5 Apr 18, 2021
82dbbde
Add test testCummulative(), to make sure multiple calls were adding t…
JimClarke5 Apr 18, 2021
47af116
Fix typo in testCumulative method name
JimClarke5 Apr 18, 2021
1f40a81
remove print statememt
JimClarke5 Apr 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,021 changes: 1,021 additions & 0 deletions tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/AUC.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.metrics;

/**
* Specifies the type of the curve to be computed, {@link #ROC} for a Receiver Operator
* Characteristic curve [default] or {@link #PR} for a Precision-Recall-curve.
*/
public enum AUCCurve {
/** Receiver Operator Characteristic curve */
ROC,
/** Precision-Recall-curve */
PR;

/**
* Gets the AUCCurve enum value by name, regardless of case
*
* @param name the name of the AUCCurve enum value.
* @return the AUCCurve enum value.
*/
public AUCCurve get(String name) {
return AUCCurve.valueOf(name.toUpperCase());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.metrics;

/**
* Specifies the Riemann summation method used. {@link #INTERPOLATION} (default) applies mid-point
* summation scheme for ROC. For PR-AUC, interpolates (true/false) positives but not the ratio that
* is precision (see Davis &amp; Goadrich 2006 for details); {@link #MINORING} applies left
* summation for increasing intervals and right summation for decreasing intervals; {@link
* #MAJORING} does the opposite.
*
* @see <a href="https://www.biostat.wisc.edu/~page/rocpr.pdf">Davis &amp; Goadrich. 2006</a>
* @see <a href="https://en.wikipedia.org/wiki/Riemann_sum">Riemann summation method</a>
*/
public enum AUCSummationMethod {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The javadoc describing the enum constants should probably be on the constants themselves.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

INTERPOLATION,
MAJORING,
MINORING;

/**
* Gets the AUCSummationMethod enum value by name, regardless of case
*
* @param name the name of the AUCSummationMethod enum value.
* @return the AUCSummationMethod enum value.
*/
public AUCSummationMethod get(String name) {
return AUCSummationMethod.valueOf(name.toUpperCase());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.metrics;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.LossTuple;
import org.tensorflow.framework.metrics.impl.LossMetric;
import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
import org.tensorflow.framework.metrics.impl.MetricsHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

import static org.tensorflow.framework.utils.CastHelper.cast;

/**
* Metric that calculates how often predictions equals labels.
*
* <p>This metric creates two local variables, total and count that are used to compute the
* frequency with which <code>predictions</code> matches <code>labels</code>. This frequency is
* ultimately returned as binary accuracy: an idempotent operation that simply divides total by
* count.
*
* <p>If sampleWeights is <code>null</code>, weights default to 1. Use sampleWeights of 0 to mask
* values.
*
* @param <T> The data type for the metric result
*/
public class Accuracy<T extends TNumber> extends MeanMetricWrapper<T> implements LossMetric<T> {

/**
* Creates an Accuracy Metric using {@link Class#getSimpleName()} for the metric name
*
* @param tf the TensorFlow Ops
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public Accuracy(Ops tf, long seed, Class<T> type) {
this(tf, null, seed, type);
}

/**
* Creates an Accuracy Metric
*
* @param tf the TensorFlow Ops
* @param name the name of the metric, if null then {@link Class#getSimpleName()} is used
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public Accuracy(Ops tf, String name, long seed, Class<T> type) {
super(tf, name, seed, type);
setLoss(this);
}

/** {@inheritDoc} */
@Override
public Operand<T> call(
Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions) {
Operand<T> tLabels = cast(getTF(), labels, getResultType());
Operand<T> tPredictions = cast(getTF(), predictions, getResultType());
LossTuple<T> tuple =
MetricsHelper.raggedAssertCompatibleAndGetFlatValues(getTF(), tLabels, tPredictions);
tLabels = tuple.getLabels();
tPredictions = tuple.getTarget();

if (!predictions.shape().isCompatibleWith(labels.shape())) {
throw new IllegalArgumentException(
String.format(
"Shapes %s and %s are incompatible",
predictions.shape().toString(), labels.shape().toString()));
}

// cast TBool to result type
return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.metrics;

import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.impl.LossMetric;
import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

import static org.tensorflow.framework.utils.CastHelper.cast;

/**
* Metric that calculates how often predictions matches binary labels.
*
* <p>This metric creates two local variables, total and count that are used to compute the
* frequency with which <code>predictions</code> matches <code>labels</code>. This frequency is
* ultimately returned as binary accuracy: an idempotent operation that simply divides total by
* count.
*
* <p>If sampleWeights is <code>null</code>, weights default to 1. Use sampleWeights of 0 to mask
* values.
*
* @param <T> The data type for the metric result
*/
public class BinaryAccuracy<T extends TNumber> extends MeanMetricWrapper<T>
implements LossMetric<T> {
/** the default threshold value for deciding whether prediction values are 1 or 0 */
public static final float DEFAULT_THRESHOLD = 0.5f;

/** the threshold value for deciding whether prediction values are 1 or 0 */
private final float threshold;

/**
* Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name and
* {@link #DEFAULT_THRESHOLD} for the threshold value.
*
* @param tf the TensorFlow Ops
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public BinaryAccuracy(Ops tf, long seed, Class<T> type) {
this(tf, null, DEFAULT_THRESHOLD, seed, type);
}

/**
* Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name
*
* @param tf the TensorFlow Ops
* @param threshold a threshold for deciding whether prediction values are 1 or 0
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public BinaryAccuracy(Ops tf, float threshold, long seed, Class<T> type) {
this(tf, null, threshold, seed, type);
}

/**
* Creates a BinaryAccuracy Metric
*
* @param tf the TensorFlow Ops
* @param name the name of the metric, if null then {@link Class#getSimpleName()} is used
* @param threshold a threshold for deciding whether prediction values are 1 or 0
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class<T> type) {
super(tf, name, seed, type);
this.threshold = threshold;
setLoss(this);
}

/** {@inheritDoc} */
@Override
public Operand<T> call(
Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions) {

Operand<T> tPredictions = cast(getTF(), predictions, getResultType());
Operand<T> thresholdCast = cast(getTF(), getTF().constant(threshold), getResultType());
tPredictions =
cast(getTF(), getTF().math.greater(tPredictions, thresholdCast), getResultType());
Operand<T> tLabels = cast(getTF(), labels, getResultType());
return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.metrics;

import org.tensorflow.Operand;
import org.tensorflow.framework.metrics.impl.LossMetric;
import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.OneHot;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TNumber;

import static org.tensorflow.framework.utils.CastHelper.cast;

/**
* Metric that calculates how often predictions matches one-hot labels.
*
* <p>You can provide <code>logits</code> of classes as <code>predictions</code>y_pred, since argmax
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a leftover "y_pred" here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 done

* of <code>logits</code> and probabilities are same.
*
* <p>This metric creates two local variables, <code>total</code> and <code>count</code> that are
* used to compute the frequency with which <code>predictions</code> matches <code>labels</code>.
* This frequency is ultimately returned as categorical accuracy: an idempotent operation that
* simply divides total by count.
*
* <p><code>predictions</code> and <code>labels</code> should be passed in as vectors of
* probabilities, rather than as labels. If necessary, use {@link
* org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand
* <code>labels</code> as a vector.
*
* <p>If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values.
*
* @param <T> The data type for the metric result
*/
public class CategoricalAccuracy<T extends TNumber> extends MeanMetricWrapper<T>
implements LossMetric<T> {

/**
* Creates a CategoricalAccuracy metric, using {@link Class#getSimpleName()} for the metric name
*
* @param tf the TensorFlow Ops
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public CategoricalAccuracy(Ops tf, long seed, Class<T> type) {
this(tf, null, seed, type);
}

/**
* Creates a CategoricalAccuracy metric
*
* @param tf the TensorFlow Ops
* @param name the name of the metric, if null then {@link Class#getSimpleName()} is used
* @param seed the seed for random number generation. An initializer created with a given seed
* will always produce the same random tensor for a given shape and data type.
* @param type the data type for the variables
*/
public CategoricalAccuracy(Ops tf, String name, long seed, Class<T> type) {
super(tf, name, seed, type);
super.setLoss(this);
}

/** {@inheritDoc} */
@Override
public Operand<T> call(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want these parameters to be Operand<? extends T>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The <T> is the internal type for the Metric and the result is that type, Class<T>, while the Operands for call are <? extends TNumber>. This allows the operands to vary somewhat, but the result is based on the type for the metric class itself.
This allows the return value to be use in setting the Metric's Variables which are type Variable<T>, defined in Reduce. Baring this, the result would still have to be cast to the Class<T>.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 makes sense

Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions) {
Operand<TInt64> trueMax = getTF().math.argMax(labels, getTF().constant(-1));

Operand<TInt64> predMax = getTF().math.argMax(predictions, getTF().constant(-1));
return cast(getTF(), getTF().math.equal(trueMax, predMax), getResultType());
}
}
Loading