
Commit db2c434

qa uses base classes and is testable (#38993)
* qa uses base classes and is testable
* evaluator list input for base class
1 parent c694e3f commit db2c434

File tree: 6 files changed, +131 −100 lines changed


sdk/evaluation/azure-ai-evaluation/assets.json
+1 −1

@@ -2,5 +2,5 @@
   "AssetsRepo": "Azure/azure-sdk-assets",
   "AssetsRepoPrefixPath": "python",
   "TagPrefix": "python/evaluation/azure-ai-evaluation",
-  "Tag": "python/evaluation/azure-ai-evaluation_326efc986d"
+  "Tag": "python/evaluation/azure-ai-evaluation_23e89ff5ac"
 }

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/__init__.py
+2 −0

@@ -5,9 +5,11 @@
 from ._base_eval import EvaluatorBase
 from ._base_prompty_eval import PromptyEvaluatorBase
 from ._base_rai_svc_eval import RaiServiceEvaluatorBase
+from ._base_multi_eval import MultiEvaluatorBase

 __all__ = [
     "EvaluatorBase",
     "PromptyEvaluatorBase",
     "RaiServiceEvaluatorBase",
+    "MultiEvaluatorBase",
 ]
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_multi_eval.py (new file)
+61 −0

@@ -0,0 +1,61 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+from concurrent.futures import as_completed
+from typing import TypeVar, Dict, List
+
+from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor
+from typing_extensions import override
+
+from azure.ai.evaluation._evaluators._common import EvaluatorBase
+
+T = TypeVar("T")
+
+
+class MultiEvaluatorBase(EvaluatorBase[T]):
+    """
+    Base class for evaluators that contain and run multiple other evaluators to produce a
+    suite of metrics.
+
+    Child classes still need to implement the __call__ methods, but they shouldn't need a _do_eval.
+
+    :param evaluators: The list of evaluators to run when this evaluator is called.
+    :type evaluators: List[~azure.ai.evaluation._evaluators._common.EvaluatorBase]
+    :param kwargs: Additional arguments to pass to the evaluator.
+    :type kwargs: Any
+    :return: An evaluator that runs multiple other evaluators and combines their results.
+    """
+
+    def __init__(self, evaluators: List[EvaluatorBase[T]], **kwargs):
+        super().__init__()
+        self._parallel = kwargs.pop("_parallel", True)
+        self._evaluators = evaluators
+
+    @override
+    async def _do_eval(self, eval_input: Dict) -> Dict[str, T]:
+        """Run each evaluator, possibly in parallel, and combine the results into
+        a single large dictionary containing each evaluation. Inputs are passed
+        directly to each evaluator without additional processing.
+
+
+        :param eval_input: The input to the evaluation function.
+        :type eval_input: Dict
+        :return: The evaluation result.
+        :rtype: Dict
+        """
+        results: Dict[str, T] = {}
+        if self._parallel:
+            with ThreadPoolExecutor() as executor:
+                # pylint: disable=no-value-for-parameter
+                futures = {executor.submit(evaluator, **eval_input): evaluator for evaluator in self._evaluators}
+
+                for future in as_completed(futures):
+                    results.update(future.result())
+        else:
+            for evaluator in self._evaluators:
+                result = evaluator(**eval_input)
+                # Ignore is to avoid mypy getting upset over the amount of duck-typing
+                # that's going on to shove evaluators around like this.
+                results.update(result)  # type: ignore[arg-type]
+
+        return results

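For context, a minimal sketch (not part of this commit) of how a composite evaluator can be built on the new MultiEvaluatorBase, following the same pattern the ContentSafetyEvaluator and QAEvaluator changes below use: build the list of child evaluators, hand it to super().__init__, and delegate __call__ to the base class. The TextPairEvaluator name is hypothetical; F1ScoreEvaluator and BleuScoreEvaluator are existing evaluators exercised by the test file in this commit.

from typing import Union

from typing_extensions import override

from azure.ai.evaluation import BleuScoreEvaluator, F1ScoreEvaluator
from azure.ai.evaluation._evaluators._common import MultiEvaluatorBase


class TextPairEvaluator(MultiEvaluatorBase[Union[str, float]]):
    """Hypothetical composite evaluator that runs two ground-truth-based metrics at once."""

    def __init__(self, **kwargs):
        # The child evaluators share the same inputs (response + ground_truth here);
        # MultiEvaluatorBase._do_eval fans eval_input out to each of them.
        evaluators = [
            F1ScoreEvaluator(),
            BleuScoreEvaluator(),
        ]
        super().__init__(evaluators=evaluators, **kwargs)

    @override
    def __call__(self, *args, **kwargs):
        """Delegate to the shared EvaluatorBase.__call__ machinery."""
        return super().__call__(*args, **kwargs)

Because EvaluatorBase drives the call, the _parallel flag popped in MultiEvaluatorBase.__init__ still decides whether the children run through the thread pool or sequentially.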
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py
+5 −42

@@ -1,13 +1,11 @@
 # ---------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
-from concurrent.futures import as_completed
-from typing import Callable, Dict, List, Union
+from typing import Dict, List, Union

-from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor
 from typing_extensions import overload, override

-from azure.ai.evaluation._evaluators._common import EvaluatorBase
+from azure.ai.evaluation._evaluators._common import MultiEvaluatorBase
 from azure.ai.evaluation._model_configurations import Conversation
 from azure.ai.evaluation._common._experimental import experimental

@@ -18,7 +16,7 @@


 @experimental
-class ContentSafetyEvaluator(EvaluatorBase[Union[str, float]]):
+class ContentSafetyEvaluator(MultiEvaluatorBase[Union[str, float]]):
     """
     Initialize a content safety evaluator configured to evaluate content safety metrics for QA scenario.

@@ -44,16 +42,14 @@ class ContentSafetyEvaluator(EvaluatorBase[Union[str, float]]):
     id = "content_safety"
     """Evaluator identifier, experimental and to be used only with evaluation in cloud."""

-    # TODO address 3579092 to re-enabled parallel evals.
     def __init__(self, credential, azure_ai_project, **kwargs):
-        super().__init__()
-        self._parallel = kwargs.pop("_parallel", True)
-        self._evaluators: List[Callable[..., Dict[str, Union[str, float]]]] = [
+        evaluators = [
             ViolenceEvaluator(credential, azure_ai_project),
             SexualEvaluator(credential, azure_ai_project),
             SelfHarmEvaluator(credential, azure_ai_project),
             HateUnfairnessEvaluator(credential, azure_ai_project),
         ]
+        super().__init__(evaluators=evaluators, **kwargs)

     @overload
     def __call__(
@@ -109,36 +105,3 @@ def __call__( # pylint: disable=docstring-missing-param
         :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]]
         """
         return super().__call__(*args, **kwargs)
-
-    @override
-    async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[str, float]]:
-        """Perform the evaluation using the Azure AI RAI service.
-        The exact evaluation performed is determined by the evaluation metric supplied
-        by the child class initializer.
-
-        :param eval_input: The input to the evaluation function.
-        :type eval_input: Dict
-        :return: The evaluation result.
-        :rtype: Dict
-        """
-        query = eval_input.get("query", None)
-        response = eval_input.get("response", None)
-        conversation = eval_input.get("conversation", None)
-        results: Dict[str, Union[str, float]] = {}
-        # TODO fix this to not explode on empty optional inputs (PF SKD error)
-        if self._parallel:
-            with ThreadPoolExecutor() as executor:
-                # pylint: disable=no-value-for-parameter
-                futures = {
-                    executor.submit(evaluator, query=query, response=response, conversation=conversation): evaluator
-                    for evaluator in self._evaluators
-                }
-
-                for future in as_completed(futures):
-                    results.update(future.result())
-        else:
-            for evaluator in self._evaluators:
-                result = evaluator(query=query, response=response, conversation=conversation)
-                results.update(result)
-
-        return results

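A minimal usage sketch for the refactored ContentSafetyEvaluator, assuming a valid Azure AI project and credential (all identifiers below are placeholders, and DefaultAzureCredential is just one convenient way to authenticate). The _parallel flag is forwarded through **kwargs and popped by MultiEvaluatorBase.__init__, which defaults it to True.

from azure.identity import DefaultAzureCredential

from azure.ai.evaluation import ContentSafetyEvaluator

# Placeholder project scope; substitute real values.
azure_ai_project = {
    "subscription_id": "<subscription-id>",
    "resource_group_name": "<resource-group>",
    "project_name": "<project-name>",
}

# _parallel=False runs the violence/sexual/self-harm/hate-unfairness child
# evaluators sequentially instead of via the thread pool in MultiEvaluatorBase.
safety_eval = ContentSafetyEvaluator(
    credential=DefaultAzureCredential(),
    azure_ai_project=azure_ai_project,
    _parallel=False,
)

result = safety_eval(
    query="What is the capital of France?",
    response="Paris is the capital of France.",
)
# Expect keys like "violence", "violence_score", "violence_reason", and the
# same trio for the other three child evaluators.
print(result)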
sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/_qa.py
+32 −27

@@ -2,10 +2,11 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------

-from concurrent.futures import as_completed
-from typing import Callable, Dict, List, Union
+from typing import Union

-from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor
+from typing_extensions import overload, override
+
+from azure.ai.evaluation._evaluators._common import MultiEvaluatorBase

 from .._coherence import CoherenceEvaluator
 from .._f1_score import F1ScoreEvaluator
@@ -15,7 +16,7 @@
 from .._similarity import SimilarityEvaluator


-class QAEvaluator:
+class QAEvaluator(MultiEvaluatorBase[Union[str, float]]):
     """
     Initialize a question-answer evaluator configured for a specific Azure OpenAI model.

@@ -46,18 +47,39 @@ class QAEvaluator:
     """Evaluator identifier, experimental and to be used only with evaluation in cloud."""

     def __init__(self, model_config, **kwargs):
-        self._parallel = kwargs.pop("_parallel", False)
-
-        self._evaluators: List[Union[Callable[..., Dict[str, Union[str, float]]], Callable[..., Dict[str, float]]]] = [
+        evaluators = [
             GroundednessEvaluator(model_config),
             RelevanceEvaluator(model_config),
             CoherenceEvaluator(model_config),
             FluencyEvaluator(model_config),
             SimilarityEvaluator(model_config),
             F1ScoreEvaluator(),
         ]
+        super().__init__(evaluators=evaluators, **kwargs)
+
+    @overload  # type: ignore
+    def __call__(self, *, query: str, response: str, context: str, ground_truth: str):
+        """
+        Evaluates question-answering scenario.
+
+        :keyword query: The query to be evaluated.
+        :paramtype query: str
+        :keyword response: The response to be evaluated.
+        :paramtype response: str
+        :keyword context: The context to be evaluated.
+        :paramtype context: str
+        :keyword ground_truth: The ground truth to be evaluated.
+        :paramtype ground_truth: str
+        :return: The scores for QA scenario.
+        :rtype: Dict[str, Union[str, float]]
+        """

-    def __call__(self, *, query: str, response: str, context: str, ground_truth: str, **kwargs):
+    @override
+    def __call__(  # pylint: disable=docstring-missing-param
+        self,
+        *args,
+        **kwargs,
+    ):
         """
         Evaluates question-answering scenario.

@@ -72,22 +94,5 @@ def __call__(self, *, query: str, response: str, context: str, ground_truth: str
         :return: The scores for QA scenario.
         :rtype: Dict[str, Union[str, float]]
         """
-        results: Dict[str, Union[str, float]] = {}
-        if self._parallel:
-            with ThreadPoolExecutor() as executor:
-                futures = {
-                    executor.submit(
-                        evaluator, query=query, response=response, context=context, ground_truth=ground_truth, **kwargs
-                    ): evaluator
-                    for evaluator in self._evaluators
-                }
-
-                # Collect results as they complete
-                for future in as_completed(futures):
-                    results.update(future.result())
-        else:
-            for evaluator in self._evaluators:
-                result = evaluator(query=query, response=response, context=context, ground_truth=ground_truth, **kwargs)
-                results.update(result)
-
-        return results
+
+        return super().__call__(*args, **kwargs)

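A matching sketch for the now-subclassed QAEvaluator, assuming an Azure OpenAI model configuration (placeholder values below). The keyword arguments mirror the new __call__ overload; the fan-out to the six child evaluators happens in MultiEvaluatorBase._do_eval.

from azure.ai.evaluation import QAEvaluator

# Placeholder Azure OpenAI configuration; substitute real values.
model_config = {
    "azure_endpoint": "https://<your-resource>.openai.azure.com",
    "azure_deployment": "<deployment-name>",
    "api_key": "<api-key>",
}

qa_eval = QAEvaluator(model_config=model_config)

scores = qa_eval(
    query="What is the capital of France?",
    response="Paris is the capital of France.",
    context="France's capital city is Paris.",
    ground_truth="Paris",
)
# Expect keys such as "f1_score", "groundedness", "relevance", "coherence",
# "fluency", and "similarity" (plus gpt_*/ *_reason variants where produced).
print(scores)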
sdk/evaluation/azure-ai-evaluation/tests/e2etests/test_mass_evaluate.py
+30 −30

@@ -19,7 +19,7 @@
     RelevanceEvaluator,
     SimilarityEvaluator,
     GroundednessEvaluator,
-    # QAEvaluator,
+    QAEvaluator,
     ContentSafetyEvaluator,
     GroundednessProEvaluator,
     ProtectedMaterialEvaluator,
@@ -88,7 +88,7 @@ def test_evaluate_singleton_inputs(self, model_config, azure_cred, project_scope
             "fluency": FluencyEvaluator(model_config),
             "relevance": RelevanceEvaluator(model_config),
             "similarity": SimilarityEvaluator(model_config),
-            # "qa": QAEvaluator(model_config),
+            "qa": QAEvaluator(model_config),
             "grounded_pro": GroundednessProEvaluator(azure_cred, project_scope),
             "protected_material": ProtectedMaterialEvaluator(azure_cred, project_scope),
             "indirect_attack": IndirectAttackEvaluator(azure_cred, project_scope),
@@ -105,7 +105,7 @@ def test_evaluate_singleton_inputs(self, model_config, azure_cred, project_scope
         row_result_df = pd.DataFrame(result["rows"])
         metrics = result["metrics"]

-        assert len(row_result_df.keys()) == 48  # 63 with qa
+        assert len(row_result_df.keys()) == 63
         assert len(row_result_df["inputs.query"]) == 3
         assert len(row_result_df["inputs.context"]) == 3
         assert len(row_result_df["inputs.response"]) == 3
@@ -154,23 +154,23 @@ def test_evaluate_singleton_inputs(self, model_config, azure_cred, project_scope
         assert len(row_result_df["outputs.content_safety.violence"]) == 3
         assert len(row_result_df["outputs.content_safety.violence_score"]) == 3
         assert len(row_result_df["outputs.content_safety.violence_reason"]) == 3
-        # assert len(row_result_df["outputs.qa.f1_score"]) == 3
-        # assert len(row_result_df["outputs.qa.groundedness"]) == 3
-        # assert len(row_result_df["outputs.qa.gpt_groundedness"]) == 3
-        # assert len(row_result_df["outputs.qa.groundedness_reason"]) == 3
-        # assert len(row_result_df["outputs.qa.coherence"]) == 3
-        # assert len(row_result_df["outputs.qa.gpt_coherence"]) == 3
-        # assert len(row_result_df["outputs.qa.coherence_reason"]) == 3
-        # assert len(row_result_df["outputs.qa.fluency"]) == 3
-        # assert len(row_result_df["outputs.qa.gpt_fluency"]) == 3
-        # assert len(row_result_df["outputs.qa.fluency_reason"]) == 3
-        # assert len(row_result_df["outputs.qa.relevance"]) == 3
-        # assert len(row_result_df["outputs.qa.gpt_relevance"]) == 3
-        # assert len(row_result_df["outputs.qa.relevance_reason"]) == 3
-        # assert len(row_result_df["outputs.qa.similarity"]) == 3
-        # assert len(row_result_df["outputs.qa.gpt_similarity"]) == 3
+        assert len(row_result_df["outputs.qa.f1_score"]) == 3
+        assert len(row_result_df["outputs.qa.groundedness"]) == 3
+        assert len(row_result_df["outputs.qa.gpt_groundedness"]) == 3
+        assert len(row_result_df["outputs.qa.groundedness_reason"]) == 3
+        assert len(row_result_df["outputs.qa.coherence"]) == 3
+        assert len(row_result_df["outputs.qa.gpt_coherence"]) == 3
+        assert len(row_result_df["outputs.qa.coherence_reason"]) == 3
+        assert len(row_result_df["outputs.qa.fluency"]) == 3
+        assert len(row_result_df["outputs.qa.gpt_fluency"]) == 3
+        assert len(row_result_df["outputs.qa.fluency_reason"]) == 3
+        assert len(row_result_df["outputs.qa.relevance"]) == 3
+        assert len(row_result_df["outputs.qa.gpt_relevance"]) == 3
+        assert len(row_result_df["outputs.qa.relevance_reason"]) == 3
+        assert len(row_result_df["outputs.qa.similarity"]) == 3
+        assert len(row_result_df["outputs.qa.gpt_similarity"]) == 3

-        assert len(metrics.keys()) == 28  # 39 with qa
+        assert len(metrics.keys()) == 39
         assert metrics["f1_score.f1_score"] >= 0
         assert metrics["gleu.gleu_score"] >= 0
         assert metrics["bleu.bleu_score"] >= 0
@@ -199,17 +199,17 @@ def test_evaluate_singleton_inputs(self, model_config, azure_cred, project_scope
         assert metrics["protected_material.protected_material_defect_rate"] >= 0
         assert metrics["indirect_attack.xpia_defect_rate"] >= 0
         assert metrics["eci.eci_defect_rate"] >= 0
-        # assert metrics["qa.f1_score"] >= 0
-        # assert metrics["qa.groundedness"] >= 0
-        # assert metrics["qa.gpt_groundedness"] >= 0
-        # assert metrics["qa.coherence"] >= 0
-        # assert metrics["qa.gpt_coherence"] >= 0
-        # assert metrics["qa.fluency"] >= 0
-        # assert metrics["qa.gpt_fluency"] >= 0
-        # assert metrics["qa.relevance"] >= 0
-        # assert metrics["qa.gpt_relevance"] >= 0
-        # assert metrics["qa.similarity"] >= 0
-        # assert metrics["qa.gpt_similarity"] >= 0
+        assert metrics["qa.f1_score"] >= 0
+        assert metrics["qa.groundedness"] >= 0
+        assert metrics["qa.gpt_groundedness"] >= 0
+        assert metrics["qa.coherence"] >= 0
+        assert metrics["qa.gpt_coherence"] >= 0
+        assert metrics["qa.fluency"] >= 0
+        assert metrics["qa.gpt_fluency"] >= 0
+        assert metrics["qa.relevance"] >= 0
+        assert metrics["qa.gpt_relevance"] >= 0
+        assert metrics["qa.similarity"] >= 0
+        assert metrics["qa.gpt_similarity"] >= 0

     def test_evaluate_conversation(self, model_config, data_convo_file, azure_cred, project_scope):
         evaluators = {

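The re-enabled assertions above correspond to the columns and aggregate metrics that the "qa" entry contributes when passed to evaluate(). A hedged sketch of that flow, assuming a JSONL dataset with query/response/context/ground_truth fields (file name and model configuration are placeholders):

from azure.ai.evaluation import evaluate, QAEvaluator

model_config = {  # placeholder Azure OpenAI configuration
    "azure_endpoint": "https://<your-resource>.openai.azure.com",
    "azure_deployment": "<deployment-name>",
    "api_key": "<api-key>",
}

result = evaluate(
    data="qa_dataset.jsonl",  # placeholder path; one JSON object per line
    evaluators={"qa": QAEvaluator(model_config=model_config)},
)

# Per-row outputs are namespaced by the evaluator key, e.g. "outputs.qa.f1_score";
# aggregate metrics land under keys like "qa.gpt_similarity".
rows = result["rows"]
metrics = result["metrics"]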