13
13
# limitations under the License.
14
14
"""Confusion matrix Plot."""
15
15
16
- from typing import Any , Dict , Optional
16
+ from typing import Any , Dict , Optional , List
17
17
18
18
from tensorflow_model_analysis .metrics import binary_confusion_matrices
19
19
from tensorflow_model_analysis .metrics import metric_types
@@ -30,73 +30,83 @@ class ConfusionMatrixPlot(metric_types.Metric):
30
30
31
31
def __init__ (self ,
32
32
num_thresholds : int = DEFAULT_NUM_THRESHOLDS ,
33
- name : str = CONFUSION_MATRIX_PLOT_NAME ):
33
+ name : str = CONFUSION_MATRIX_PLOT_NAME ,
34
+ ** kwargs ):
34
35
"""Initializes confusion matrix plot.
35
36
36
37
Args:
37
38
num_thresholds: Number of thresholds to use when discretizing the curve.
38
39
Values must be > 1. Defaults to 1000.
39
40
name: Metric name.
41
+ **kwargs: (Optional) Additional args to pass along to init (and eventually
42
+ on to _confusion_matrix_plot). These kwargs are useful for subclasses to
43
+ pass information from their init to the create_computation_fn.
40
44
"""
41
45
super ().__init__ (
42
- metric_util .merge_per_key_computations (_confusion_matrix_plot ),
46
+ metric_util .merge_per_key_computations (self . _confusion_matrix_plot ),
43
47
num_thresholds = num_thresholds ,
44
- name = name )
48
+ name = name ,
49
+ ** kwargs )
50
+
51
+ def _confusion_matrix_plot (
52
+ self ,
53
+ num_thresholds : int = DEFAULT_NUM_THRESHOLDS ,
54
+ name : str = CONFUSION_MATRIX_PLOT_NAME ,
55
+ eval_config : Optional [config_pb2 .EvalConfig ] = None ,
56
+ model_name : str = '' ,
57
+ output_name : str = '' ,
58
+ sub_key : Optional [metric_types .SubKey ] = None ,
59
+ aggregation_type : Optional [metric_types .AggregationType ] = None ,
60
+ class_weights : Optional [Dict [int , float ]] = None ,
61
+ example_weighted : bool = False ,
62
+ preprocessors : Optional [List [metric_types .Preprocessor ]] = None ,
63
+ ) -> metric_types .MetricComputations :
64
+ """Returns metric computations for confusion matrix plots."""
65
+ key = metric_types .PlotKey (
66
+ name = name ,
67
+ model_name = model_name ,
68
+ output_name = output_name ,
69
+ sub_key = sub_key ,
70
+ example_weighted = example_weighted )
71
+
72
+ # The interoploation strategy used here matches how the legacy post export
73
+ # metrics calculated its plots.
74
+ thresholds = [
75
+ i * 1.0 / num_thresholds for i in range (0 , num_thresholds + 1 )
76
+ ]
77
+ thresholds = [- 1e-6 ] + thresholds
78
+
79
+ # Make sure matrices are calculated.
80
+ matrices_computations = binary_confusion_matrices .binary_confusion_matrices (
81
+ # Use a custom name since we have a custom interpolation strategy which
82
+ # will cause the default naming used by the binary confusion matrix to
83
+ # be very long.
84
+ name = (binary_confusion_matrices .BINARY_CONFUSION_MATRICES_NAME + '_' +
85
+ name ),
86
+ eval_config = eval_config ,
87
+ model_name = model_name ,
88
+ output_name = output_name ,
89
+ sub_key = sub_key ,
90
+ aggregation_type = aggregation_type ,
91
+ class_weights = class_weights ,
92
+ example_weighted = example_weighted ,
93
+ thresholds = thresholds ,
94
+ use_histogram = True ,
95
+ preprocessors = preprocessors )
96
+ matrices_key = matrices_computations [- 1 ].keys [- 1 ]
97
+
98
+ def result (
99
+ metrics : Dict [metric_types .MetricKey , Any ]
100
+ ) -> Dict [metric_types .MetricKey , binary_confusion_matrices .Matrices ]:
101
+ return {
102
+ key : metrics [matrices_key ].to_proto ().confusion_matrix_at_thresholds
103
+ }
104
+
105
+ derived_computation = metric_types .DerivedMetricComputation (
106
+ keys = [key ], result = result )
107
+ computations = matrices_computations
108
+ computations .append (derived_computation )
109
+ return computations
45
110
46
111
47
112
metric_types .register_metric (ConfusionMatrixPlot )
48
-
49
-
50
- def _confusion_matrix_plot (
51
- num_thresholds : int = DEFAULT_NUM_THRESHOLDS ,
52
- name : str = CONFUSION_MATRIX_PLOT_NAME ,
53
- eval_config : Optional [config_pb2 .EvalConfig ] = None ,
54
- model_name : str = '' ,
55
- output_name : str = '' ,
56
- sub_key : Optional [metric_types .SubKey ] = None ,
57
- aggregation_type : Optional [metric_types .AggregationType ] = None ,
58
- class_weights : Optional [Dict [int , float ]] = None ,
59
- example_weighted : bool = False ) -> metric_types .MetricComputations :
60
- """Returns metric computations for confusion matrix plots."""
61
- key = metric_types .PlotKey (
62
- name = name ,
63
- model_name = model_name ,
64
- output_name = output_name ,
65
- sub_key = sub_key ,
66
- example_weighted = example_weighted )
67
-
68
- # The interoploation strategy used here matches how the legacy post export
69
- # metrics calculated its plots.
70
- thresholds = [i * 1.0 / num_thresholds for i in range (0 , num_thresholds + 1 )]
71
- thresholds = [- 1e-6 ] + thresholds
72
-
73
- # Make sure matrices are calculated.
74
- matrices_computations = binary_confusion_matrices .binary_confusion_matrices (
75
- # Use a custom name since we have a custom interpolation strategy which
76
- # will cause the default naming used by the binary confusion matrix to be
77
- # very long.
78
- name = (binary_confusion_matrices .BINARY_CONFUSION_MATRICES_NAME + '_' +
79
- name ),
80
- eval_config = eval_config ,
81
- model_name = model_name ,
82
- output_name = output_name ,
83
- sub_key = sub_key ,
84
- aggregation_type = aggregation_type ,
85
- class_weights = class_weights ,
86
- example_weighted = example_weighted ,
87
- thresholds = thresholds ,
88
- use_histogram = True )
89
- matrices_key = matrices_computations [- 1 ].keys [- 1 ]
90
-
91
- def result (
92
- metrics : Dict [metric_types .MetricKey , Any ]
93
- ) -> Dict [metric_types .MetricKey , binary_confusion_matrices .Matrices ]:
94
- return {
95
- key : metrics [matrices_key ].to_proto ().confusion_matrix_at_thresholds
96
- }
97
-
98
- derived_computation = metric_types .DerivedMetricComputation (
99
- keys = [key ], result = result )
100
- computations = matrices_computations
101
- computations .append (derived_computation )
102
- return computations
0 commit comments