Skip to content

Commit 23c5fe8

Browse files
tf-model-analysis-team authored and tfx-copybara committed
Build a new metric to support confusion matrix plots in object detections.
PiperOrigin-RevId: 488845773
1 parent 9568ce3 commit 23c5fe8

7 files changed

+498
-210
lines changed

tensorflow_model_analysis/metrics/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@
9898
from tensorflow_model_analysis.metrics.multi_class_confusion_matrix_plot import MultiClassConfusionMatrixPlot
9999
from tensorflow_model_analysis.metrics.multi_label_confusion_matrix_plot import MultiLabelConfusionMatrixPlot
100100
from tensorflow_model_analysis.metrics.ndcg import NDCG
101+
from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionMaxRecall
102+
from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionPrecision
103+
from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionPrecisionAtRecall
104+
from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionRecall
105+
from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_metrics import ObjectDetectionThresholdAtRecall
106+
from tensorflow_model_analysis.metrics.object_detection_confusion_matrix_plot import ObjectDetectionConfusionMatrixPlot
101107
from tensorflow_model_analysis.metrics.object_detection_metrics import COCOAveragePrecision
102108
from tensorflow_model_analysis.metrics.object_detection_metrics import COCOAverageRecall
103109
from tensorflow_model_analysis.metrics.object_detection_metrics import COCOMeanAveragePrecision

tensorflow_model_analysis/metrics/confusion_matrix_plot.py

Lines changed: 69 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
"""Confusion matrix Plot."""
1515

16-
from typing import Any, Dict, Optional
16+
from typing import Any, Dict, Optional, List
1717

1818
from tensorflow_model_analysis.metrics import binary_confusion_matrices
1919
from tensorflow_model_analysis.metrics import metric_types
@@ -30,73 +30,83 @@ class ConfusionMatrixPlot(metric_types.Metric):
3030

3131
def __init__(self,
3232
num_thresholds: int = DEFAULT_NUM_THRESHOLDS,
33-
name: str = CONFUSION_MATRIX_PLOT_NAME):
33+
name: str = CONFUSION_MATRIX_PLOT_NAME,
34+
**kwargs):
3435
"""Initializes confusion matrix plot.
3536
3637
Args:
3738
num_thresholds: Number of thresholds to use when discretizing the curve.
3839
Values must be > 1. Defaults to 1000.
3940
name: Metric name.
41+
**kwargs: (Optional) Additional args to pass along to init (and eventually
42+
on to _confusion_matrix_plot). These kwargs are useful for subclasses to
43+
pass information from their init to the create_computation_fn.
4044
"""
4145
super().__init__(
42-
metric_util.merge_per_key_computations(_confusion_matrix_plot),
46+
metric_util.merge_per_key_computations(self._confusion_matrix_plot),
4347
num_thresholds=num_thresholds,
44-
name=name)
48+
name=name,
49+
**kwargs)
50+
51+
def _confusion_matrix_plot(
52+
self,
53+
num_thresholds: int = DEFAULT_NUM_THRESHOLDS,
54+
name: str = CONFUSION_MATRIX_PLOT_NAME,
55+
eval_config: Optional[config_pb2.EvalConfig] = None,
56+
model_name: str = '',
57+
output_name: str = '',
58+
sub_key: Optional[metric_types.SubKey] = None,
59+
aggregation_type: Optional[metric_types.AggregationType] = None,
60+
class_weights: Optional[Dict[int, float]] = None,
61+
example_weighted: bool = False,
62+
preprocessors: Optional[List[metric_types.Preprocessor]] = None,
63+
) -> metric_types.MetricComputations:
64+
"""Returns metric computations for confusion matrix plots."""
65+
key = metric_types.PlotKey(
66+
name=name,
67+
model_name=model_name,
68+
output_name=output_name,
69+
sub_key=sub_key,
70+
example_weighted=example_weighted)
71+
72+
# The interoploation strategy used here matches how the legacy post export
73+
# metrics calculated its plots.
74+
thresholds = [
75+
i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)
76+
]
77+
thresholds = [-1e-6] + thresholds
78+
79+
# Make sure matrices are calculated.
80+
matrices_computations = binary_confusion_matrices.binary_confusion_matrices(
81+
# Use a custom name since we have a custom interpolation strategy which
82+
# will cause the default naming used by the binary confusion matrix to
83+
# be very long.
84+
name=(binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME + '_' +
85+
name),
86+
eval_config=eval_config,
87+
model_name=model_name,
88+
output_name=output_name,
89+
sub_key=sub_key,
90+
aggregation_type=aggregation_type,
91+
class_weights=class_weights,
92+
example_weighted=example_weighted,
93+
thresholds=thresholds,
94+
use_histogram=True,
95+
preprocessors=preprocessors)
96+
matrices_key = matrices_computations[-1].keys[-1]
97+
98+
def result(
99+
metrics: Dict[metric_types.MetricKey, Any]
100+
) -> Dict[metric_types.MetricKey, binary_confusion_matrices.Matrices]:
101+
return {
102+
key: metrics[matrices_key].to_proto().confusion_matrix_at_thresholds
103+
}
104+
105+
derived_computation = metric_types.DerivedMetricComputation(
106+
keys=[key], result=result)
107+
computations = matrices_computations
108+
computations.append(derived_computation)
109+
return computations
45110

46111

47112
metric_types.register_metric(ConfusionMatrixPlot)
48-
49-
50-
def _confusion_matrix_plot(
51-
num_thresholds: int = DEFAULT_NUM_THRESHOLDS,
52-
name: str = CONFUSION_MATRIX_PLOT_NAME,
53-
eval_config: Optional[config_pb2.EvalConfig] = None,
54-
model_name: str = '',
55-
output_name: str = '',
56-
sub_key: Optional[metric_types.SubKey] = None,
57-
aggregation_type: Optional[metric_types.AggregationType] = None,
58-
class_weights: Optional[Dict[int, float]] = None,
59-
example_weighted: bool = False) -> metric_types.MetricComputations:
60-
"""Returns metric computations for confusion matrix plots."""
61-
key = metric_types.PlotKey(
62-
name=name,
63-
model_name=model_name,
64-
output_name=output_name,
65-
sub_key=sub_key,
66-
example_weighted=example_weighted)
67-
68-
# The interoploation strategy used here matches how the legacy post export
69-
# metrics calculated its plots.
70-
thresholds = [i * 1.0 / num_thresholds for i in range(0, num_thresholds + 1)]
71-
thresholds = [-1e-6] + thresholds
72-
73-
# Make sure matrices are calculated.
74-
matrices_computations = binary_confusion_matrices.binary_confusion_matrices(
75-
# Use a custom name since we have a custom interpolation strategy which
76-
# will cause the default naming used by the binary confusion matrix to be
77-
# very long.
78-
name=(binary_confusion_matrices.BINARY_CONFUSION_MATRICES_NAME + '_' +
79-
name),
80-
eval_config=eval_config,
81-
model_name=model_name,
82-
output_name=output_name,
83-
sub_key=sub_key,
84-
aggregation_type=aggregation_type,
85-
class_weights=class_weights,
86-
example_weighted=example_weighted,
87-
thresholds=thresholds,
88-
use_histogram=True)
89-
matrices_key = matrices_computations[-1].keys[-1]
90-
91-
def result(
92-
metrics: Dict[metric_types.MetricKey, Any]
93-
) -> Dict[metric_types.MetricKey, binary_confusion_matrices.Matrices]:
94-
return {
95-
key: metrics[matrices_key].to_proto().confusion_matrix_at_thresholds
96-
}
97-
98-
derived_computation = metric_types.DerivedMetricComputation(
99-
keys=[key], result=result)
100-
computations = matrices_computations
101-
computations.append(derived_computation)
102-
return computations

tensorflow_model_analysis/metrics/metric_util.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,41 @@
4040
_EPSILON = 1e-7
4141

4242

43+
def validate_object_detection_arguments(
44+
class_id: Optional[Union[int, List[int]]],
45+
class_weight: Optional[Union[float, List[float]]],
46+
area_range: Optional[Tuple[float, float]] = None,
47+
max_num_detections: Optional[int] = None,
48+
labels_to_stack: Optional[List[str]] = None,
49+
predictions_to_stack: Optional[List[str]] = None,
50+
output_name: Optional[str] = None) -> None:
51+
"""Validate the arguments for object detection related functions."""
52+
if class_id is None:
53+
raise ValueError('class_id must be provided if use object' ' detection.')
54+
if isinstance(class_id, int):
55+
class_id = [class_id]
56+
if class_weight is not None:
57+
if isinstance(class_weight, float):
58+
class_weight = [class_weight]
59+
for weight in class_weight:
60+
if weight < 0:
61+
raise ValueError(f'class_weight = {class_weight} must '
62+
'not be negative.')
63+
if len(class_id) != len(class_weight):
64+
raise ValueError('Mismatch of length between class_id = '
65+
f'{class_id} and class_weight = '
66+
f'{class_weight}.')
67+
if area_range is not None:
68+
if len(area_range) != 2 or area_range[0] > area_range[1]:
69+
raise ValueError(f'area_range = {area_range} must be a valid interval.')
70+
if max_num_detections is not None and max_num_detections <= 0:
71+
raise ValueError(f'max_num_detections = {max_num_detections} must be '
72+
'positive.')
73+
if output_name and (labels_to_stack or predictions_to_stack):
74+
raise ValueError('The metric does not support specifying the output name'
75+
' when there are keys/outputs specified to be stacked.')
76+
77+
4378
def generate_private_name_from_arguments(name: str, **kwargs) -> str:
4479
"""Generate names for used metrics.
4580

0 commit comments

Comments (0)