
Commit 8471f21

tf-model-analysis-team authored and tfx-copybara committed
Implements the ModelCosineSimilarity metric class, which compares candidate model(s) predictions against baseline model predictions using cosine similarity.
PiperOrigin-RevId: 643161747
1 parent abec4cd commit 8471f21
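
For context, the per-example score the new combiner accumulates is the standard cosine similarity between the baseline and candidate prediction vectors, averaged over all examples in the slice. A minimal standalone sketch of the per-example computation (the `cosine_similarity` name is illustrative; it mirrors the `_compute_cosine_similarity` helper added below, and the sample vectors come from the new test file):

import numpy as np


def cosine_similarity(baseline: np.ndarray, candidate: np.ndarray) -> float:
  # <b, c> / (||b|| * ||c||): 1.0 when the two prediction vectors point in
  # the same direction, 0.0 when orthogonal, -1.0 when opposite.
  return np.dot(baseline, candidate) / (
      np.linalg.norm(baseline) * np.linalg.norm(candidate)
  )


# Two of the prediction vectors used in the new test file:
print(cosine_similarity(np.array([1.0, 0.5, 0.5, 1.0]),
                        np.array([0.5, 1.0, 1.0, 0.5])))  # 0.8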

File tree

3 files changed: +336 -0 lines changed

RELEASE.md (+1 line)

@@ -18,6 +18,7 @@
   [py-ml-metrics](https://pypi.org/project/py-ml-metrics/) package.
 * Adds Constituent Flip Rate Metrics: SymmetricFlipRate, NegToNegFlipRate,
   NegToPosFlipRate, PosToNegFlipRate, PosToPosFlipRate.
+* Adds Model Cosine Similarity Metrics.
 * Depend on tensorflow-estimator package explicitly.

 ## Bug fixes and other Changes
tensorflow_model_analysis/metrics/model_cosine_similarity.py (new file, +188 lines)
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model cosine similarity metrics."""

from collections.abc import Iterable
import dataclasses
from typing import Any, Optional

import apache_beam as beam
import numpy as np
from tensorflow_model_analysis import metrics
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import model_util

_COSINE_SIMILARITY_METRIC_NAME = 'model_cosine_similarity'


def _compute_cosine_similarity(
    baseline_prediction: np.ndarray[Any, Any],
    candidate_prediction: np.ndarray[Any, Any],
) -> float:
  """Computes cosine similarity between two predictions of np.ndarrays."""
  return np.dot(baseline_prediction, candidate_prediction) / (
      np.linalg.norm(baseline_prediction) * np.linalg.norm(candidate_prediction)
  )


@dataclasses.dataclass
class _CosineSimilarityAccumulator:
  """Accumulator for computing average cosine similarity."""

  num_examples: int = 0
  sum_cosine_similarity: float = 0.0

  def merge(self, other: '_CosineSimilarityAccumulator'):
    self.num_examples += other.num_examples
    self.sum_cosine_similarity += other.sum_cosine_similarity

  def get_average(self) -> float:
    if self.num_examples == 0:
      return np.nan
    return self.sum_cosine_similarity / self.num_examples


class ModelCosineSimilarity(metrics.Metric):
  """Compares baseline and candidate model predictions via cosine similarity."""

  def __init__(self, name: str = _COSINE_SIMILARITY_METRIC_NAME):
    super().__init__(self._metric_computation, name=name)

  def _metric_computation(
      self,
      name: str,
      eval_config: config_pb2.EvalConfig,
      model_names: Iterable[str],
      output_names: Optional[Iterable[str]] = ('',),
      sub_keys: Optional[Iterable[metric_types.SubKey]] = None,
  ) -> metrics.MetricComputations:
    """Returns the metric computations for calculating the cosine similarity.

    Args:
      name: Metric name for the cosine similarity metric.
      eval_config: The EvalConfig for this TFMA evaluation. This is used to
        identify which model is the baseline.
      model_names: The names of the baseline model and the candidate model(s).
      output_names: The set of output names for which to compute this metric.
      sub_keys: The set of sub_key settings for which to compute this metric.
    """
    computations = []

    # Get the baseline model name.
    baseline_spec = model_util.get_baseline_model_spec(eval_config)
    baseline_model_name = baseline_spec.name if baseline_spec else None

    for candidate_model_name in model_names:
      if candidate_model_name == baseline_model_name:
        continue
      for output_name in output_names:
        for sub_key in sub_keys or (None,):
          # Define the metric key.
          key = metric_types.MetricKey(
              name=name,
              model_name=candidate_model_name,
              output_name=output_name,
              sub_key=sub_key,
              is_diff=True,
          )

          # Append cosine similarity calculation to computations.
          computations.append(
              metrics.MetricComputation(
                  keys=[key],
                  preprocessors=None,
                  combiner=_ModelCosineSimilarityCombiner(
                      metric_key=key,
                      eval_config=eval_config,
                      baseline_model_name=baseline_model_name,
                      model_name=candidate_model_name,
                      output_name=output_name,
                  ),
              )
          )

    return computations


class _ModelCosineSimilarityCombiner(beam.CombineFn):
  """A combiner for computing the cosine similarity between models."""

  def __init__(
      self,
      metric_key: metrics.MetricKey,
      eval_config: config_pb2.EvalConfig,
      baseline_model_name: str,
      model_name: str,
      output_name: str,
  ):
    self._metric_key = metric_key
    self._eval_config = eval_config
    self._baseline_model_name = baseline_model_name
    self._model_name = model_name
    self._output_name = output_name

  def create_accumulator(self) -> _CosineSimilarityAccumulator:
    return _CosineSimilarityAccumulator()

  def add_input(
      self,
      accumulator: _CosineSimilarityAccumulator,
      element: metric_types.StandardMetricInputs,
  ) -> _CosineSimilarityAccumulator:
    _, baseline_prediction, _ = next(
        metric_util.to_label_prediction_example_weight(
            inputs=element,
            eval_config=self._eval_config,
            model_name=self._baseline_model_name,
            output_name=self._output_name,
            flatten=False,
            allow_none=False,
        )
    )

    _, candidate_prediction, _ = next(
        metric_util.to_label_prediction_example_weight(
            inputs=element,
            eval_config=self._eval_config,
            model_name=self._model_name,
            output_name=self._output_name,
            flatten=False,
            allow_none=False,
        )
    )

    accumulator.merge(
        _CosineSimilarityAccumulator(
            num_examples=1,
            sum_cosine_similarity=_compute_cosine_similarity(
                baseline_prediction, candidate_prediction
            ),
        )
    )

    return accumulator

  def merge_accumulators(
      self, accumulators: Iterable[_CosineSimilarityAccumulator]
  ) -> _CosineSimilarityAccumulator:
    # Consume the first accumulator as the seed, then merge the rest into it.
    # Using a single iterator avoids double-counting the first element when
    # `accumulators` is a re-iterable collection such as a list.
    accumulators = iter(accumulators)
    result = next(accumulators)
    for accumulator in accumulators:
      result.merge(accumulator)
    return result

  def extract_output(
      self, accumulator: _CosineSimilarityAccumulator
  ) -> dict[metrics.MetricKey, float]:
    return {self._metric_key: accumulator.get_average()}
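
Downstream, the class is meant to be requested like any other TFMA metric. A minimal sketch (not part of this commit) of wiring it into an EvalConfig through a MetricsSpec, using the standard class_name/module fields of MetricConfig; the module path follows the import used by the new test file:

from google.protobuf import text_format
from tensorflow_model_analysis.proto import config_pb2

# Hypothetical EvalConfig: one baseline and one candidate model, with the
# new metric requested by class name and module path.
eval_config = text_format.Parse(
    """
    model_specs { name: "baseline" is_baseline: true }
    model_specs { name: "candidate" }
    metrics_specs {
      metrics {
        class_name: "ModelCosineSimilarity"
        module: "tensorflow_model_analysis.metrics.model_cosine_similarity"
      }
    }
    """,
    config_pb2.EvalConfig(),
)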
tensorflow_model_analysis/metrics/model_cosine_similarity_test.py (new file, +147 lines)
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for model cosine similarity metrics."""

from absl.testing import absltest
from absl.testing import parameterized
import apache_beam as beam
from apache_beam.testing import util
import numpy as np
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import metric_util
from tensorflow_model_analysis.metrics import model_cosine_similarity
from tensorflow_model_analysis.proto import config_pb2

from google.protobuf import text_format

_PREDICTION_A = np.array([1.0, 0.5, 0.5, 1.0])
_PREDICTION_B = np.array([0.5, 1.0, 1.0, 0.5])
_PREDICTION_C = np.array([0.25, 0.1, 0.9, 0.75])


class ModelCosineSimilarityMetricsTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name='no_change',
          prediction_pairs=[
              (_PREDICTION_A, _PREDICTION_A),
              (_PREDICTION_B, _PREDICTION_B),
              (_PREDICTION_C, _PREDICTION_C),
          ],
          # cs(p1, p2):
          #   np.dot(p1, p2) / (np.linalg.norm(p1) * np.linalg.norm(p2))
          # cs(_PREDICTION_A/B/C, _PREDICTION_A/B/C) = 1.0
          expected_average_cosine_similarity=1.0,
      ),
      dict(
          testcase_name='small_change',
          prediction_pairs=[
              (_PREDICTION_A, _PREDICTION_A),
              (_PREDICTION_B, _PREDICTION_A),
              (_PREDICTION_A, _PREDICTION_B),
          ],
          # cs(_PREDICTION_A, _PREDICTION_A) = 1.0
          # cs(_PREDICTION_B, _PREDICTION_A) = 0.8
          # cs(_PREDICTION_A, _PREDICTION_B) = 0.8
          expected_average_cosine_similarity=0.8666666666666666,
      ),
      dict(
          testcase_name='large_change',
          prediction_pairs=[
              (_PREDICTION_C, _PREDICTION_A),
              (_PREDICTION_A, _PREDICTION_B),
              (_PREDICTION_B, _PREDICTION_C),
          ],
          # cs(_PREDICTION_C, _PREDICTION_A) = 0.7892004626469845
          # cs(_PREDICTION_A, _PREDICTION_B) = 0.8
          # cs(_PREDICTION_B, _PREDICTION_C) = 0.7892004626469845
          expected_average_cosine_similarity=0.7928003084313229,
      ),
  )
  def test_cosine_similarity(
      self, prediction_pairs, expected_average_cosine_similarity
  ):
    baseline_model_name = 'baseline'
    candidate_model_name = 'candidate'

    eval_config = text_format.Parse(
        """
        model_specs {
          name: "baseline"
          is_baseline: true
        }
        model_specs {
          name: "candidate"
        }
        """,
        config_pb2.EvalConfig(),
    )

    computations = model_cosine_similarity.ModelCosineSimilarity().computations(
        eval_config=eval_config,
        model_names=['baseline', 'candidate'],
        output_names=[''],
    )
    self.assertLen(computations, 1)
    cosine_similarity = computations[0]

    examples = []
    for baseline_prediction, candidate_prediction in prediction_pairs:
      examples.append({
          constants.LABELS_KEY: [0],
          constants.PREDICTIONS_KEY: {
              baseline_model_name: baseline_prediction,
              candidate_model_name: candidate_prediction,
          },
      })

    with beam.Pipeline() as pipeline:
      result = (
          pipeline
          | 'Create' >> beam.Create(examples)
          | 'Process' >> beam.Map(metric_util.to_standard_metric_inputs)
          | 'AddSlice' >> beam.Map(lambda x: ((), x))
          | 'ComputeMetric' >> beam.CombinePerKey(cosine_similarity.combiner)
      )

      def check_result(got):
        try:
          self.assertLen(got, 1)
          got_slice_key, got_metrics = got[0]
          self.assertEqual(got_slice_key, ())

          metric_key = metric_types.MetricKey(
              name=model_cosine_similarity._COSINE_SIMILARITY_METRIC_NAME,
              model_name=candidate_model_name,
              output_name='',
              is_diff=True,
          )

          self.assertIn(metric_key, got_metrics)
          self.assertIsInstance(got_metrics[metric_key], float)
          self.assertAlmostEqual(
              got_metrics[metric_key],
              expected_average_cosine_similarity,
          )
        except AssertionError as err:
          raise util.BeamAssertException(err)

      util.assert_that(result, check_result, label='result')


if __name__ == '__main__':
  absltest.main()
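
As a sanity check, the expected averages in the parameterized cases follow directly from the definition; for instance, the `large_change` expectation can be reproduced with a few lines of NumPy (constants copied from the test above):

import numpy as np

a = np.array([1.0, 0.5, 0.5, 1.0])    # _PREDICTION_A
b = np.array([0.5, 1.0, 1.0, 0.5])    # _PREDICTION_B
c = np.array([0.25, 0.1, 0.9, 0.75])  # _PREDICTION_C

def cs(x, y):
  return np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

# 'large_change' case: mean of cs over the three (baseline, candidate) pairs.
print((cs(c, a) + cs(a, b) + cs(b, c)) / 3)  # 0.7928003084313229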
