-
Notifications
You must be signed in to change notification settings - Fork 214
Metrics phase 2 #222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Metrics phase 2 #222
Changes from 15 commits
c57a2e7
09fc07e
a99dcb4
ba294ea
04f419a
02e7ebf
e0c9ed8
b29edfd
d7f7e4c
6b4149c
e486a90
c711532
1dfb7c3
a3aa3c4
5b0374b
214140a
e038bbd
2c93884
616ebb2
af69e00
def3051
05c3d88
eb8b2f8
dcfaa0f
de9bc10
4520294
a0b7041
4dce2cf
7274cf5
5e8fac6
9ea129e
95c5b0c
86c50de
9cb4cc0
7503bbc
19c881c
ebecb5e
478441e
e9bea47
8ae78cd
b56344f
b4373dc
99d2610
e379582
96dce5c
9f4044a
5e50e95
81fc9fd
e63fc29
11748ae
743f416
5ed35f3
929ce7e
1b36693
d301b01
020eb9c
3ae642d
8b881b4
001f051
a9412ea
088e0d8
4add028
c23163c
7bd1fcf
74a548f
0eab19c
f258e38
36e0934
60d513d
ea2e3b1
ca9e395
817430f
b440c63
021df65
4df4a80
99df6b4
0710ffe
7e61ba2
efd7d43
5e907df
8856c9c
b533b2e
21029a7
e154453
eb0a7e6
a29f8e9
d47c3b8
7f46673
41fde65
3c7e3a7
4a114be
da72efd
fb1ab3a
b2937cd
d9c8352
ce07c25
e9f1a35
b176432
d087a6f
6aec4ff
9090e31
4e5906c
980bb64
b91cabf
56b0300
f1203aa
dedaede
fc0f9be
0974038
a3deb5c
8cdc776
0a19b80
82dbbde
47af116
1f40a81
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
=======================================================================*/ | ||
package org.tensorflow.framework.metrics; | ||
|
||
/** | ||
* Specifies the type of the curve to be computed, {@link #ROC} for a Receiver Operator | ||
* Characteristic curve [default] or {@link #PR} for a Precision-Recall-curve. | ||
*/ | ||
public enum AUCCurve { | ||
/** Receiver Operator Characteristic curve */ | ||
ROC, | ||
/** Precision-Recall-curve */ | ||
PR; | ||
|
||
/** | ||
* Gets the AUCCurve enum value by name, regardless of case | ||
* | ||
* @param name the name of the AUCCurve enum value. | ||
* @return the AUCCurve enum value. | ||
*/ | ||
public AUCCurve get(String name) { | ||
return AUCCurve.valueOf(name.toUpperCase()); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
=======================================================================*/ | ||
package org.tensorflow.framework.metrics; | ||
|
||
/** | ||
* Specifies the Riemann summation method used. {@link #INTERPOLATION} (default) applies mid-point | ||
* summation scheme for ROC. For PR-AUC, interpolates (true/false) positives but not the ratio that | ||
* is precision (see Davis & Goadrich 2006 for details); {@link #MINORING} applies left | ||
* summation for increasing intervals and right summation for decreasing intervals; {@link | ||
* #MAJORING} does the opposite. | ||
* | ||
* @see <a href="https://www.biostat.wisc.edu/~page/rocpr.pdf">Davis & Goadrich. 2006</a> | ||
* @see <a href="https://en.wikipedia.org/wiki/Riemann_sum">Riemann summation method</a> | ||
*/ | ||
public enum AUCSummationMethod { | ||
INTERPOLATION, | ||
MAJORING, | ||
MINORING; | ||
|
||
/** | ||
* Gets the AUCSummationMethod enum value by name, regardless of case | ||
* | ||
* @param name the name of the AUCSummationMethod enum value. | ||
* @return the AUCSummationMethod enum value. | ||
*/ | ||
public AUCSummationMethod get(String name) { | ||
return AUCSummationMethod.valueOf(name.toUpperCase()); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
=======================================================================*/ | ||
package org.tensorflow.framework.metrics; | ||
|
||
import org.tensorflow.Operand; | ||
import org.tensorflow.framework.losses.impl.LossTuple; | ||
import org.tensorflow.framework.metrics.impl.LossMetric; | ||
import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; | ||
import org.tensorflow.framework.metrics.impl.MetricsHelper; | ||
import org.tensorflow.op.Ops; | ||
import org.tensorflow.types.family.TNumber; | ||
|
||
import static org.tensorflow.framework.utils.CastHelper.cast; | ||
|
||
/** | ||
* Metric that calculates how often predictions equals labels. | ||
* | ||
* <p>This metric creates two local variables, total and count that are used to compute the | ||
* frequency with which <code>predictions</code> matches <code>labels</code>. This frequency is | ||
* ultimately returned as binary accuracy: an idempotent operation that simply divides total by | ||
* count. | ||
* | ||
* <p>If sampleWeights is <code>null</code>, weights default to 1. Use sampleWeights of 0 to mask | ||
* values. | ||
* | ||
* @param <T> The data type for the metric result | ||
*/ | ||
public class Accuracy<T extends TNumber> extends MeanMetricWrapper<T> implements LossMetric<T> { | ||
|
||
/** | ||
* Creates an Accuracy Metric using {@link Class#getSimpleName()} for the metric name | ||
* | ||
* @param tf the TensorFlow Ops | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
* @param type the data type for the variables | ||
*/ | ||
public Accuracy(Ops tf, long seed, Class<T> type) { | ||
this(tf, null, seed, type); | ||
} | ||
|
||
/** | ||
* Creates an Accuracy Metric | ||
* | ||
* @param tf the TensorFlow Ops | ||
* @param name the name of the metric, if null then {@link Class#getSimpleName()} is used | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
* @param type the data type for the variables | ||
*/ | ||
public Accuracy(Ops tf, String name, long seed, Class<T> type) { | ||
super(tf, name, seed, type); | ||
setLoss(this); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public Operand<T> call( | ||
Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions) { | ||
Operand<T> tLabels = cast(getTF(), labels, getResultType()); | ||
Operand<T> tPredictions = cast(getTF(), predictions, getResultType()); | ||
LossTuple<T> tuple = | ||
MetricsHelper.raggedAssertCompatibleAndGetFlatValues(getTF(), tLabels, tPredictions); | ||
tLabels = tuple.getLabels(); | ||
tPredictions = tuple.getTarget(); | ||
|
||
if (!predictions.shape().isCompatibleWith(labels.shape())) { | ||
throw new IllegalArgumentException( | ||
String.format( | ||
"Shapes %s and %s are incompatible", | ||
predictions.shape().toString(), labels.shape().toString())); | ||
} | ||
|
||
// cast TBool to result type | ||
return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
=======================================================================*/ | ||
package org.tensorflow.framework.metrics; | ||
|
||
import org.tensorflow.Operand; | ||
import org.tensorflow.framework.metrics.impl.LossMetric; | ||
import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; | ||
import org.tensorflow.op.Ops; | ||
import org.tensorflow.types.family.TNumber; | ||
|
||
import static org.tensorflow.framework.utils.CastHelper.cast; | ||
|
||
/** | ||
* Metric that calculates how often predictions matches binary labels. | ||
* | ||
* <p>This metric creates two local variables, total and count that are used to compute the | ||
* frequency with which <code>predictions</code> matches <code>labels</code>. This frequency is | ||
* ultimately returned as binary accuracy: an idempotent operation that simply divides total by | ||
* count. | ||
* | ||
* <p>If sampleWeights is <code>null</code>, weights default to 1. Use sampleWeights of 0 to mask | ||
* values. | ||
* | ||
* @param <T> The data type for the metric result | ||
*/ | ||
public class BinaryAccuracy<T extends TNumber> extends MeanMetricWrapper<T> | ||
implements LossMetric<T> { | ||
/** the default threshold value for deciding whether prediction values are 1 or 0 */ | ||
public static final float DEFAULT_THRESHOLD = 0.5f; | ||
|
||
/** the threshold value for deciding whether prediction values are 1 or 0 */ | ||
private final float threshold; | ||
|
||
/** | ||
* Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name and | ||
* {@link #DEFAULT_THRESHOLD} for the threshold value. | ||
* | ||
* @param tf the TensorFlow Ops | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
* @param type the data type for the variables | ||
*/ | ||
public BinaryAccuracy(Ops tf, long seed, Class<T> type) { | ||
this(tf, null, DEFAULT_THRESHOLD, seed, type); | ||
} | ||
|
||
/** | ||
* Creates a BinaryAccuracy Metric using {@link Class#getSimpleName()} for the metric name | ||
* | ||
* @param tf the TensorFlow Ops | ||
* @param threshold a threshold for deciding whether prediction values are 1 or 0 | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
* @param type the data type for the variables | ||
*/ | ||
public BinaryAccuracy(Ops tf, float threshold, long seed, Class<T> type) { | ||
this(tf, null, threshold, seed, type); | ||
} | ||
|
||
/** | ||
* Creates a BinaryAccuracy Metric | ||
* | ||
* @param tf the TensorFlow Ops | ||
* @param name the name of the metric, if null then {@link Class#getSimpleName()} is used | ||
* @param threshold a threshold for deciding whether prediction values are 1 or 0 | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
* @param type the data type for the variables | ||
*/ | ||
public BinaryAccuracy(Ops tf, String name, float threshold, long seed, Class<T> type) { | ||
super(tf, name, seed, type); | ||
this.threshold = threshold; | ||
setLoss(this); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public Operand<T> call( | ||
Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions) { | ||
|
||
Operand<T> tPredictions = cast(getTF(), predictions, getResultType()); | ||
Operand<T> thresholdCast = cast(getTF(), getTF().constant(threshold), getResultType()); | ||
tPredictions = | ||
cast(getTF(), getTF().math.greater(tPredictions, thresholdCast), getResultType()); | ||
Operand<T> tLabels = cast(getTF(), labels, getResultType()); | ||
return cast(getTF(), getTF().math.equal(tLabels, tPredictions), getResultType()); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
=======================================================================*/ | ||
package org.tensorflow.framework.metrics; | ||
|
||
import org.tensorflow.Operand; | ||
import org.tensorflow.framework.metrics.impl.LossMetric; | ||
import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; | ||
import org.tensorflow.op.Ops; | ||
import org.tensorflow.op.core.OneHot; | ||
import org.tensorflow.types.TInt64; | ||
import org.tensorflow.types.family.TNumber; | ||
|
||
import static org.tensorflow.framework.utils.CastHelper.cast; | ||
|
||
/** | ||
* Metric that calculates how often predictions matches one-hot labels. | ||
* | ||
* <p>You can provide <code>logits</code> of classes as <code>predictions</code>y_pred, since argmax | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a leftover "y_pred" here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 done |
||
* of <code>logits</code> and probabilities are same. | ||
* | ||
* <p>This metric creates two local variables, <code>total</code> and <code>count</code> that are | ||
* used to compute the frequency with which <code>predictions</code> matches <code>labels</code>. | ||
* This frequency is ultimately returned as categorical accuracy: an idempotent operation that | ||
* simply divides total by count. | ||
* | ||
* <p><code>predictions</code> and <code>labels</code> should be passed in as vectors of | ||
* probabilities, rather than as labels. If necessary, use {@link | ||
* org.tensorflow.op.Ops#oneHot(Operand, Operand, Operand, Operand, OneHot.Options...)} to expand | ||
* <code>labels</code> as a vector. | ||
* | ||
* <p>If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values. | ||
* | ||
* @param <T> The data type for the metric result | ||
*/ | ||
public class CategoricalAccuracy<T extends TNumber> extends MeanMetricWrapper<T> | ||
implements LossMetric<T> { | ||
|
||
/** | ||
* Creates a CategoricalAccuracy metric, using {@link Class#getSimpleName()} for the metric name | ||
* | ||
* @param tf the TensorFlow Ops | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
* @param type the data type for the variables | ||
*/ | ||
public CategoricalAccuracy(Ops tf, long seed, Class<T> type) { | ||
this(tf, null, seed, type); | ||
} | ||
|
||
/** | ||
* Creates a CategoricalAccuracy metric | ||
* | ||
* @param tf the TensorFlow Ops | ||
* @param name the name of the metric, if null then {@link Class#getSimpleName()} is used | ||
* @param seed the seed for random number generation. An initializer created with a given seed | ||
* will always produce the same random tensor for a given shape and data type. | ||
* @param type the data type for the variables | ||
*/ | ||
public CategoricalAccuracy(Ops tf, String name, long seed, Class<T> type) { | ||
super(tf, name, seed, type); | ||
super.setLoss(this); | ||
} | ||
|
||
/** {@inheritDoc} */ | ||
@Override | ||
public Operand<T> call( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want these parameters to be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 makes sense |
||
Operand<? extends TNumber> labels, Operand<? extends TNumber> predictions) { | ||
Operand<TInt64> trueMax = getTF().math.argMax(labels, getTF().constant(-1)); | ||
|
||
Operand<TInt64> predMax = getTF().math.argMax(predictions, getTF().constant(-1)); | ||
return cast(getTF(), getTF().math.equal(trueMax, predMax), getResultType()); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The javadoc describing the enum constants should probably be on the constants themselves.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK