TFMA supports the following metrics:
- Standard keras metrics (tf.keras.metrics.*). Note that you do not need a keras model to use keras metrics. Metrics are computed outside of the graph in beam using the metrics classes directly.
- Standard TFMA metrics (tfma.metrics.*)
- Custom keras metrics (metrics derived from tf.keras.metrics.Metric)
- Custom TFMA metrics (metrics derived from tfma.metrics.Metric using custom beam combiners or metrics derived from other metrics).
TFMA also provides built-in support for converting binary classification metrics for use with multi-class/multi-label problems:
- Binarization based on class ID, top K, etc.
- Aggregated metrics based on micro averaging, macro averaging, etc.
TFMA also provides built-in support for query/ranking based metrics where the examples are grouped by a query key automatically in the pipeline.
Combined, there are over 50 standard metrics and plots available for a variety of problems including regression, binary classification, multi-class/multi-label classification, ranking, etc.
There are two ways to configure metrics in TFMA: (1) using the MetricsSpec proto or (2) by creating instances of tf.keras.metrics.* and/or tfma.metrics.* classes in python and using tfma.metrics.specs_from_metrics to convert them to MetricsSpecs.
The following sections describe example configurations for different types of machine learning problems.
The following is an example configuration setup for a regression problem.
Consult the tf.keras.metrics.* and tfma.metrics.* modules for additional supported metrics.
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "WeightedExampleCount" }
    metrics { class_name: "MeanSquaredError" }
    metrics { class_name: "Accuracy" }
    metrics { class_name: "MeanLabel" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics {
      class_name: "CalibrationPlot"
      config: '"min_value": 0, "max_value": 10'
    }
  }
""", tfma.EvalConfig()).metrics_specs
This same setup can be created using the following python code:
metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tfma.metrics.WeightedExampleCount(name='weighted_example_count'),
    tf.keras.metrics.MeanSquaredError(name='mse'),
    tf.keras.metrics.Accuracy(name='accuracy'),
    tfma.metrics.MeanLabel(name='mean_label'),
    tfma.metrics.MeanPrediction(name='mean_prediction'),
    tfma.metrics.Calibration(name='calibration'),
    tfma.metrics.CalibrationPlot(
        name='calibration_plot', min_value=0, max_value=10)
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
Note that this setup is also available by calling tfma.metrics.default_regression_specs.
The following is an example configuration setup for a binary classification
problem. Consult the tf.keras.metrics.* and tfma.metrics.* modules for additional supported metrics.
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "WeightedExampleCount" }
    metrics { class_name: "BinaryCrossentropy" }
    metrics { class_name: "BinaryAccuracy" }
    metrics { class_name: "AUC" }
    metrics { class_name: "AUCPrecisionRecall" }
    metrics { class_name: "MeanLabel" }
    metrics { class_name: "MeanPrediction" }
    metrics { class_name: "Calibration" }
    metrics { class_name: "ConfusionMatrixPlot" }
    metrics { class_name: "CalibrationPlot" }
  }
""", tfma.EvalConfig()).metrics_specs
This same setup can be created using the following python code:
metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tfma.metrics.WeightedExampleCount(name='weighted_example_count'),
    tf.keras.metrics.BinaryCrossentropy(name='binary_crossentropy'),
    tf.keras.metrics.BinaryAccuracy(name='accuracy'),
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    tf.keras.metrics.AUC(
        name='auc_precision_recall', curve='PR', num_thresholds=10000),
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tfma.metrics.MeanLabel(name='mean_label'),
    tfma.metrics.MeanPrediction(name='mean_prediction'),
    tfma.metrics.Calibration(name='calibration'),
    tfma.metrics.ConfusionMatrixPlot(name='confusion_matrix_plot'),
    tfma.metrics.CalibrationPlot(name='calibration_plot')
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
Note that this setup is also available by calling tfma.metrics.default_binary_classification_specs.
The following is an example configuration setup for a multi-class classification
problem. Consult the tf.keras.metrics.* and tfma.metrics.* modules for additional supported metrics.
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "ExampleCount" }
    metrics { class_name: "WeightedExampleCount" }
    metrics { class_name: "SparseCategoricalCrossentropy" }
    metrics { class_name: "SparseCategoricalAccuracy" }
    metrics { class_name: "Precision" config: '"top_k": 1' }
    metrics { class_name: "Precision" config: '"top_k": 3' }
    metrics { class_name: "Recall" config: '"top_k": 1' }
    metrics { class_name: "Recall" config: '"top_k": 3' }
  }
""", tfma.EvalConfig()).metrics_specs
This same setup can be created using the following python code:
metrics = [
    tfma.metrics.ExampleCount(name='example_count'),
    tfma.metrics.WeightedExampleCount(name='weighted_example_count'),
    tf.keras.metrics.SparseCategoricalCrossentropy(
        name='sparse_categorical_crossentropy'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
    tf.keras.metrics.Precision(name='precision', top_k=1),
    tf.keras.metrics.Precision(name='precision', top_k=3),
    tf.keras.metrics.Recall(name='recall', top_k=1),
    tf.keras.metrics.Recall(name='recall', top_k=3),
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
Note that this setup is also available by calling tfma.metrics.default_multi_class_classification_specs.
Multi-class/multi-label metrics can be binarized to produce metrics per class, per top_k, etc. using tfma.BinarizationOptions. For example:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    binarize: { class_ids: { values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } }
    # Metrics to binarize
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs
This same setup can be created using the following python code:
metrics = [
    # Metrics to binarize
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, binarize=tfma.BinarizationOptions(
        class_ids={'values': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}))
Multi-class/multi-label metrics can be aggregated to produce a single aggregated value for a binary classification metric.
Micro averaging can be performed either independently or as part of a binarization of metrics by using the micro_average option within tfma.AggregationOptions. For example:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    aggregate: { micro_average: true }
    # Metrics to aggregate
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs
This same setup can be created using the following python code:
metrics = [
    # Metrics to aggregate
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, aggregate=tfma.AggregationOptions(micro_average=True))
Macro averaging must be performed as part of a binarization of metrics in conjunction with the macro_average or weighted_macro_average options within tfma.AggregationOptions. For example:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    binarize: { class_ids: { values: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } }
    aggregate: { macro_average: true }
    # Metrics to both binarize and aggregate
    metrics { class_name: "AUC" }
    ...
  }
""", tfma.EvalConfig()).metrics_specs
This same setup can be created using the following python code:
metrics = [
    # Metrics to both binarize and aggregate
    tf.keras.metrics.AUC(name='auc', num_thresholds=10000),
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics,
    binarize=tfma.BinarizationOptions(
        class_ids={'values': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}),
    aggregate=tfma.AggregationOptions(macro_average=True))
Query/ranking based metrics are enabled by specifying the query_key option in the metrics specs. For example:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    query_key: "doc_id"
    binarize { top_k_list: { values: [1, 2] } }
    metrics { class_name: "NDCG" config: '"gain_key": "gain"' }
  }
  metrics_specs {
    query_key: "doc_id"
    metrics { class_name: "MinLabelPosition" }
  }
""", tfma.EvalConfig()).metrics_specs
This same setup can be created using the following python code:
metrics = [
    tfma.metrics.NDCG(name='ndcg', gain_key='gain'),
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, query_key='doc_id', binarize=tfma.BinarizationOptions(
        top_k_list={'values': [1, 2]}))
metrics = [
    tfma.metrics.MinLabelPosition(name='min_label_position')
]
metrics_specs.extend(
    tfma.metrics.specs_from_metrics(metrics, query_key='doc_id'))
TFMA supports evaluating multiple models at the same time. When multi-model evaluation is performed, the names of the models associated with a set of metrics must be specified in the model_names section of the MetricsSpec. For example:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    model_names: ["my-model1", "my-model2"]
    ...
  }
""", tfma.EvalConfig()).metrics_specs
The specs_from_metrics API also supports passing model names:
metrics = [
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, model_names=['my-model1', 'my-model2'])
TFMA supports evaluating metrics on models that have different outputs. Multi-output models store their output predictions in the form of a dict keyed by output name. When multi-output models are used, the names of the outputs associated with a set of metrics must be specified in the output_names section of the MetricsSpec. For example:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    output_names: ["my-output"]
    ...
  }
""", tfma.EvalConfig()).metrics_specs
The specs_from_metrics API also supports passing output names:
metrics = [
    ...
]
metrics_specs = tfma.metrics.specs_from_metrics(
    metrics, output_names=['my-output'])
TFMA allows customizing the settings that are used with different metrics. For example, you might want to change the name, set thresholds, etc. This is done by adding a config section to the metric config. The config is specified using the JSON string version of the parameters that would be passed to the metric's __init__ method (for ease of use the leading and trailing '{' and '}' brackets may be omitted). For example:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics {
      class_name: "ConfusionMatrixAtThresholds"
      config: '"thresholds": [0.3, 0.5, 0.8]'
    }
  }
""", tfma.EvalConfig()).metrics_specs
This customization is of course also supported directly:
metrics = [
    tfma.metrics.ConfusionMatrixAtThresholds(thresholds=[0.3, 0.5, 0.8]),
]
metrics_specs = tfma.metrics.specs_from_metrics(metrics)
NOTE: It is advisable to set the number of thresholds used with AUC, etc. to 10000 because this matches the default value used by the underlying histogram calculation, which is shared between multiple metric implementations.
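For example, the same num_thresholds setting shown in the python examples above can also be applied through the config JSON when configuring via proto. This is a minimal sketch reusing the config mechanism described above; the surrounding metrics_specs contents will depend on your setup:

from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics {
      class_name: "AUC"
      config: '"num_thresholds": 10000'
    }
  }
""", tfma.EvalConfig()).metrics_specs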
The output of a metric evaluation is a series of metric keys/values and/or plot keys/values based on the configuration used.
MetricKeys are defined using a structured key type. This key uniquely identifies each of the following aspects of a metric:
- Metric name (auc, mean_label, etc)
- Model name (only used if multi-model evaluation)
- Output name (only used if multi-output models are evaluated)
- Sub key (e.g. class ID if multi-class model is binarized)
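For illustration, a key for an AUC metric binarized on class ID 0 and computed for a particular model and output might be constructed roughly as follows. This is a sketch using the tfma.metrics.MetricKey and tfma.SubKey types referenced elsewhere on this page; the field names follow the aspects listed above and should be checked against the installed API:

# Hypothetical key for an AUC metric binarized on class ID 0,
# evaluated for model 'my-model1' and output 'my-output'.
key = tfma.metrics.MetricKey(
    name='auc',
    model_name='my-model1',
    output_name='my-output',
    sub_key=tfma.SubKey(class_id=0))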
MetricValues are defined using a proto that encapsulates the different value types supported by the different metrics (e.g. double, ConfusionMatrixAtThresholds, etc).
PlotKeys are similar to metric keys except that for historical reasons all the plot values are stored in a single proto, so the plot key does not have a name.
All the supported plots are stored in a single proto called PlotData.
The return from an evaluation run is an EvalResult. This record contains slicing_metrics that encode the metric key as a multi-level dict where the levels correspond to output name, class ID, metric name, and metric value respectively. This is intended to be used for UI display in a Jupyter notebook. If access to the underlying data is needed, the metrics result file should be used instead (see metrics_for_slice.proto).
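As a rough sketch of how the slicing_metrics structure can be inspected (assuming an evaluation has already been written to a hypothetical output_path, and that each entry pairs a slice key with the nested dict described above):

# output_path is a hypothetical location of a previously written evaluation.
output_path = '/path/to/eval/output'
eval_result = tfma.load_eval_result(output_path)

# slicing_metrics is assumed to be a list of (slice_key, metrics dict) pairs,
# where the dict is nested by output name, sub key (e.g. class ID), and
# metric name as described above.
for slice_key, metrics_by_output in eval_result.slicing_metrics:
  for output_name, metrics_by_sub_key in metrics_by_output.items():
    for sub_key, metrics_by_name in metrics_by_sub_key.items():
      print(slice_key, output_name, sub_key, list(metrics_by_name.keys()))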
In addition to custom metrics that are added as part of a saved keras model (or legacy EvalSavedModel), there are two ways to customize metrics in TFMA post saving: (1) by defining a custom keras metric class and (2) by defining a custom TFMA metric class backed by a beam combiner. In both cases, the metrics are configured by specifying the name of the metric class and associated module. For example:
from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    metrics { class_name: "MyMetric" module: "my.module" }
  }
""", tfma.EvalConfig()).metrics_specs
NOTE: When customizing metrics you must ensure that the module is available to beam.
To create a custom keras metric, users need to extend tf.keras.metrics.Metric with their implementation and then make sure the metric's module is available at evaluation time.
Note that for metrics added post model save, TFMA only supports metrics that take label (i.e. y_true), prediction (y_pred), and example weight (sample_weight) as parameters to the update_state method.
The following is an example of a custom keras metric:
class MyMetric(tf.keras.metrics.Mean):

  def __init__(self, name='my_metric', dtype=None):
    super(MyMetric, self).__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    return super(MyMetric, self).update_state(
        y_pred, sample_weight=sample_weight)
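Once the module containing MyMetric is importable by the evaluation job, the metric can be referenced from the metrics specs in the same way as the generic example above. This is a sketch where "my.module" is a placeholder for wherever MyMetric is actually defined:

from google.protobuf import text_format
metrics_specs = text_format.Parse("""
  metrics_specs {
    # "my.module" is a placeholder for the module that defines MyMetric.
    metrics { class_name: "MyMetric" module: "my.module" }
  }
""", tfma.EvalConfig()).metrics_specs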
To create a custom TFMA metric, users need to extend tfma.metrics.Metric with their implementation and then make sure the metric's module is available at evaluation time.
A tfma.metrics.Metric implementation is made up of a set of kwargs that define the metric's configuration along with a function for creating the computations (possibly multiple) needed to calculate the metric's value. There are two main computation types that can be used: tfma.metrics.MetricComputation and tfma.metrics.DerivedMetricComputation, which are described in the sections below. The function that creates these computations will be passed the following parameters as input:
- eval_config: tfma.EvalConfig - The eval config passed to the evaluator (useful for looking up model spec settings such as the prediction key to use, etc).
- model_names: List[Text] - List of model names to compute metrics for (None if single-model).
- output_names: List[Text] - List of output names to compute metrics for (None if single-output).
- sub_keys: List[tfma.SubKey] - List of sub keys (class ID, top K, etc) to compute metrics for (or None).
- class_weights: Dict[int, float] - Class weights to use if computing an aggregation metric.
- query_key: Text - Query key used if computing a query/ranking based metric.
If a metric is not associated with one or more of these settings then it may leave those parameters out of its signature definition.
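As an illustration, a create-computations function that only cares about the eval config and sub keys might look roughly like this. It is a sketch, and _my_metric_computations is a hypothetical helper name:

from typing import List, Optional, Text

def _my_metric_computations(
    name: Text = 'my_metric',
    eval_config: Optional[tfma.EvalConfig] = None,
    sub_keys: Optional[List[tfma.SubKey]] = None
) -> tfma.metrics.MetricComputations:
  # Parameters this metric does not use (model_names, output_names,
  # class_weights, query_key) are simply omitted from the signature.
  computations = []
  # Append one or more tfma.metrics.MetricComputation instances here,
  # typically one per sub key.
  return computations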
If a metric is computed the same way for each model, output, and sub key, then the utility tfma.metrics.merge_per_key_computations can be used to perform the same computations for each of these inputs separately.
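For example, a metric whose computations are created by a per-key function (handling a single model, output, and sub key at a time) might be wired up roughly as follows. This is a sketch that assumes _my_metric_per_key is a hypothetical function returning the computations for one key:

from typing import Text

class MyPerKeyMetric(tfma.metrics.Metric):

  def __init__(self, name: Text = 'my_per_key_metric'):
    # merge_per_key_computations invokes _my_metric_per_key once per
    # model name / output name / sub key combination.
    super(MyPerKeyMetric, self).__init__(
        tfma.metrics.merge_per_key_computations(_my_metric_per_key),
        name=name)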
A MetricComputation is made up of a combination of a preprocessor and a combiner. The preprocessor is a beam.DoFn that takes extracts as its input and outputs the initial state that will be used by the combiner (see architecture for more info on what extracts are). If a preprocessor is not defined, then the combiner will be passed StandardMetricInputs (standard metric inputs contain labels, predictions, and example_weights). The combiner is a beam.CombineFn that takes a tuple of (slice key, preprocessor output) as its input and outputs a tuple of (slice_key, metric results dict) as its result.
Note that slicing happens between the preprocessor and combiner.
Note that if a metric computation wants to make use of the standard metric inputs but augment them with a few of the features from the features extracts, then the special FeaturePreprocessor can be used, which will merge the requested features from multiple combiners into a single shared StandardMetricInputs value that is passed to all the combiners (the combiners are responsible for reading the features they are interested in and ignoring the rest).
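A rough sketch of what this might look like inside a create-computations function follows. It assumes FeaturePreprocessor is exposed under tfma.metrics and accepts a feature_keys argument, and that key and _MyCombiner are the metric's own key and combiner (hypothetical names); check the installed version's API before relying on this:

computation = tfma.metrics.MetricComputation(
    keys=[key],
    # Request that the 'age' feature (a hypothetical feature name) be merged
    # into the standard metric inputs passed to the combiner.
    preprocessor=tfma.metrics.FeaturePreprocessor(feature_keys=['age']),
    combiner=_MyCombiner(key))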
The following is a very simple example of a TFMA metric definition for computing the ExampleCount:
from typing import Dict, Iterable, List, Text

import apache_beam as beam


class ExampleCount(tfma.metrics.Metric):

  def __init__(self, name: Text = 'example_count'):
    super(ExampleCount, self).__init__(_example_count, name=name)


def _example_count(
    name: Text = 'example_count') -> tfma.metrics.MetricComputations:
  key = tfma.metrics.MetricKey(name=name)
  return [
      tfma.metrics.MetricComputation(
          keys=[key],
          preprocessor=_ExampleCountPreprocessor(),
          combiner=_ExampleCountCombiner(key))
  ]


class _ExampleCountPreprocessor(beam.DoFn):

  def process(self, extracts: tfma.Extracts) -> Iterable[int]:
    yield 1


class _ExampleCountCombiner(beam.CombineFn):

  def __init__(self, metric_key: tfma.metrics.MetricKey):
    self._metric_key = metric_key

  def create_accumulator(self) -> int:
    return 0

  def add_input(self, accumulator: int, state: int) -> int:
    return accumulator + state

  def merge_accumulators(self, accumulators: List[int]) -> int:
    result = 0
    for accumulator in accumulators:
      result += accumulator
    return result

  def extract_output(
      self, accumulator: int) -> Dict[tfma.metrics.MetricKey, int]:
    return {self._metric_key: accumulator}
A DerivedMetricComputation is made up of a result function that is used to calculate metric values based on the output of other metric computations. The result function takes a dict of computed values as its input and outputs a dict of additional metric results.
Note that it is acceptable (recommended) to include the computations that a derived computation depends on in the list of computations created by a metric. This avoids having to pre-create and pass computations that are shared between multiple metrics. The evaluator will automatically de-dup computations that have the same definition so only one computation is actually run.
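As an illustration, a derived metric that rescales the output of the ExampleCount computation defined above might be assembled roughly as follows. This is a sketch: _rescaled_metric, base_key, and the 2x scaling are hypothetical, and it assumes the keys/result fields of tfma.metrics.DerivedMetricComputation:

def _rescaled_metric(name='rescaled_metric') -> tfma.metrics.MetricComputations:
  # Computation this derived metric depends on, included here so the
  # evaluator can de-dup it if another metric creates the same one.
  base_computations = _example_count(name='example_count')
  base_key = base_computations[-1].keys[-1]
  derived_key = tfma.metrics.MetricKey(name=name)

  def result(metrics):
    # 'metrics' holds the already computed values keyed by MetricKey.
    return {derived_key: 2.0 * metrics[base_key]}

  return base_computations + [
      tfma.metrics.DerivedMetricComputation(keys=[derived_key], result=result)
  ]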
The TJUR metrics provide a good example of derived metrics.