from pydantic import BaseModel, ConfigDict, Field

+ from cleanlab_codex.internal.tlm import TLM
from cleanlab_codex.internal.utils import generate_pydantic_model_docstring
from cleanlab_codex.types.response_validation import (
    AggregatedResponseValidationResult,
    SingleResponseValidationResult,
)
- from cleanlab_codex.types.tlm import TLM
+ from cleanlab_codex.types.tlm import TLMConfig
from cleanlab_codex.utils.errors import MissingDependencyError
from cleanlab_codex.utils.prompt import default_format_prompt

_DEFAULT_FALLBACK_SIMILARITY_THRESHOLD: float = 0.7
_DEFAULT_TRUSTWORTHINESS_THRESHOLD: float = 0.5
_DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD: float = 0.5
+ _DEFAULT_TLM_CONFIG: TLMConfig = TLMConfig()

Query = str
Context = str
@@ -77,13 +79,12 @@ class BadResponseDetectionConfig(BaseModel):
    )

    # Shared config (for untrustworthiness and unhelpfulness checks)
-     tlm: Optional[TLM] = Field(
-         default=None,
-         description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).",
+     tlm_config: TLMConfig = Field(
+         default=_DEFAULT_TLM_CONFIG,
+         description="TLM model configuration to use for untrustworthiness and unhelpfulness checks.",
    )


- # hack to generate better documentation for help.cleanlab.ai
BadResponseDetectionConfig.__doc__ = f"""
{BadResponseDetectionConfig.__doc__}
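For orientation, a minimal sketch of constructing the reworked config after this change; the import path for BadResponseDetectionConfig is an assumption, not taken from the diff:

# Hypothetical usage sketch -- the cleanlab_codex.response_validation import path is assumed.
from cleanlab_codex.response_validation import BadResponseDetectionConfig
from cleanlab_codex.types.tlm import TLMConfig

# The config now carries a declarative TLMConfig instead of a live TLM client.
config = BadResponseDetectionConfig(
    tlm_config=TLMConfig(),                  # same default as _DEFAULT_TLM_CONFIG
    trustworthiness_threshold=0.5,           # fields referenced elsewhere in this diff
    unhelpfulness_confidence_threshold=0.5,
)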
@@ -99,10 +100,11 @@ def is_bad_response(
    context: Optional[str] = None,
    query: Optional[str] = None,
    config: Union[BadResponseDetectionConfig, Dict[str, Any]] = _DEFAULT_CONFIG,
+     codex_access_key: Optional[str] = None,
) -> AggregatedResponseValidationResult:
    """Run a series of checks to determine if a response is bad.

-     The function returns a `AggregatedResponseValidationResult` object containing results from multiple validation checks.
+     The function returns an `AggregatedResponseValidationResult` object containing results from multiple validation checks.
    If any check fails (detects an issue), the AggregatedResponseValidationResult will evaluate to `True` when used in a boolean context.
    This means code like `if is_bad_response(...)` will enter the if-block when problems are detected.
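To make the boolean-context behavior described in that docstring concrete, here is a hedged caller-side sketch (the import path is an assumption and the sample strings are placeholders):

from cleanlab_codex.response_validation import is_bad_response

result = is_bad_response(
    "Paris is the capital of France.",        # response under validation
    query="What is the capital of France?",
    context="France's capital city is Paris.",
    codex_access_key=None,                    # new optional keyword from this change
)
if result:  # AggregatedResponseValidationResult is truthy when any check detects a problem
    print("Response flagged as bad:", result)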
@@ -146,28 +148,30 @@ def is_bad_response(
        )
    )

-     can_run_untrustworthy_check = query is not None and context is not None and config.tlm is not None
+     can_run_untrustworthy_check = query is not None and context is not None and config.tlm_config is not None
    if can_run_untrustworthy_check:
        # The if condition guarantees these are not None
        validation_checks.append(
            lambda: is_untrustworthy_response(
                response=response,
                context=cast(str, context),
                query=cast(str, query),
-                 tlm=cast(TLM, config.tlm),
+                 tlm_config=config.tlm_config,
                trustworthiness_threshold=config.trustworthiness_threshold,
                format_prompt=config.format_prompt,
+                 codex_access_key=codex_access_key,
            )
        )

-     can_run_unhelpful_check = query is not None and config.tlm is not None
+     can_run_unhelpful_check = query is not None and config.tlm_config is not None
    if can_run_unhelpful_check:
        validation_checks.append(
            lambda: is_unhelpful_response(
                response=response,
                query=cast(str, query),
-                 tlm=cast(TLM, config.tlm),
+                 tlm_config=config.tlm_config,
                confidence_score_threshold=config.unhelpfulness_confidence_threshold,
+                 codex_access_key=codex_access_key,
            )
        )
@@ -238,9 +242,11 @@ def is_untrustworthy_response(
    response: str,
    context: str,
    query: str,
-     tlm: TLM,
+     tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
    trustworthiness_threshold: float = _DEFAULT_TRUSTWORTHINESS_THRESHOLD,
    format_prompt: Callable[[str, str], str] = default_format_prompt,
+     *,
+     codex_access_key: Optional[str] = None,
) -> SingleResponseValidationResult:
    """Check if a response is untrustworthy.
@@ -252,7 +258,7 @@ def is_untrustworthy_response(
        response (str): The response to check from the assistant.
        context (str): The context information available for answering the query.
        query (str): The user's question or request.
-         tlm (TLM): The TLM model to use for evaluation.
+         tlm_config (TLMConfig): The TLM configuration to use for evaluation.
        trustworthiness_threshold (float): Score threshold (0.0-1.0) under which a response is considered untrustworthy.
            Lower values allow less trustworthy responses. Default 0.5 means responses with scores less than 0.5 are considered untrustworthy.
        format_prompt (Callable[[str, str], str]): Function that takes (query, context) and returns a formatted prompt string.
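A hedged sketch of calling this check directly with the new keyword-only arguments (import path assumed; the threshold mirrors the default documented above):

from cleanlab_codex.response_validation import is_untrustworthy_response
from cleanlab_codex.types.tlm import TLMConfig

check = is_untrustworthy_response(
    response="The Eiffel Tower is in Berlin.",
    context="The Eiffel Tower is located in Paris, France.",
    query="Where is the Eiffel Tower?",
    tlm_config=TLMConfig(),        # replaces the old tlm=... argument
    trustworthiness_threshold=0.5,
    codex_access_key=None,         # keyword-only, added by this change
)
# `check` is a SingleResponseValidationResult describing the "untrustworthy" check.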
@@ -266,8 +272,9 @@ def is_untrustworthy_response(
        response=response,
        context=context,
        query=query,
-         tlm=tlm,
+         tlm_config=tlm_config,
        format_prompt=format_prompt,
+         codex_access_key=codex_access_key,
    )
    return SingleResponseValidationResult(
        name="untrustworthy",
@@ -281,8 +288,10 @@ def score_untrustworthy_response(
    response: str,
    context: str,
    query: str,
-     tlm: TLM,
+     tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
    format_prompt: Callable[[str, str], str] = default_format_prompt,
+     *,
+     codex_access_key: Optional[str] = None,
) -> float:
    """Scores a response's trustworthiness using [TLM](/tlm), given a context and query.
@@ -298,24 +307,20 @@ def score_untrustworthy_response(
    Returns:
        float: The score of the response, between 0.0 and 1.0. A lower score indicates the response is less trustworthy.
    """
-     try:
-         from cleanlab_tlm import TLM  # noqa: F401
-     except ImportError as e:
-         raise MissingDependencyError(
-             import_name=e.name or "cleanlab_tlm",
-             package_name="cleanlab-tlm",
-             package_url="https://github.com/cleanlab/cleanlab-tlm",
-         ) from e

    prompt = format_prompt(query, context)
-     result = tlm.get_trustworthiness_score(prompt, response)
-     return float(result["trustworthiness_score"])
+     result = TLM.from_config(tlm_config, codex_access_key=codex_access_key).get_trustworthiness_score(
+         prompt, response=response
+     )
+     return float(result.trustworthiness_score)


def is_unhelpful_response(
    response: str,
    query: str,
-     tlm: TLM,
+     tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
    confidence_score_threshold: float = _DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD,
+     *,
+     codex_access_key: Optional[str] = None,
) -> SingleResponseValidationResult:
    """Check if a response is unhelpful by asking [TLM](/tlm) to evaluate it.
@@ -327,14 +332,14 @@ def is_unhelpful_response(
    Args:
        response (str): The response to check.
        query (str): User query that will be used to evaluate if the response is helpful.
-         tlm (TLM): The TLM model to use for evaluation.
+         tlm_config (TLMConfig): The TLM configuration to use for evaluation.
        confidence_score_threshold (float): Confidence threshold (0.0-1.0) above which a response is considered unhelpful.
            E.g. if confidence_score_threshold is 0.5, then responses with scores higher than 0.5 are considered unhelpful.

    Returns:
        SingleResponseValidationResult: The results of the validation check.
    """
-     score: float = score_unhelpful_response(response, query, tlm)
+     score: float = score_unhelpful_response(response, query, tlm_config, codex_access_key=codex_access_key)

    # Current implementation of `score_unhelpful_response` produces a score where a higher value means the response if more likely to be unhelpful
    # Changing the TLM prompt used in `score_unhelpful_response` may require restructuring the logic for `fails_check` and potentially adjusting
@@ -350,27 +355,20 @@ def is_unhelpful_response(
def score_unhelpful_response(
    response: str,
    query: str,
-     tlm: TLM,
+     tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
+     *,
+     codex_access_key: Optional[str] = None,
) -> float:
    """Scores a response's unhelpfulness using [TLM](/tlm), given a query.

    Args:
        response (str): The response to check.
        query (str): User query that will be used to evaluate if the response is helpful.
-         tlm (TLM): The TLM model to use for evaluation.
+         tlm_config (TLMConfig): The TLM configuration to use for evaluation.

    Returns:
        float: The score of the response, between 0.0 and 1.0. A higher score corresponds to a less helpful response.
    """
-     try:
-         from cleanlab_tlm import TLM  # noqa: F401
-     except ImportError as e:
-         raise MissingDependencyError(
-             import_name=e.name or "cleanlab_tlm",
-             package_name="cleanlab-tlm",
-             package_url="https://github.com/cleanlab/cleanlab-tlm",
-         ) from e
-
    # IMPORTANT: The current implementation couples three things that must stay in sync:
    # 1. The question phrasing ("is unhelpful?")
    # 2. The expected_unhelpful_response ("Yes")
@@ -405,5 +403,7 @@ def score_unhelpful_response(
        f"AI Assistant Response: {response}\n\n"
        f"{question}"
    )
-     result = tlm.get_trustworthiness_score(prompt, response=expected_unhelpful_response)
-     return float(result["trustworthiness_score"])
+     result = TLM.from_config(tlm_config, codex_access_key=codex_access_key).get_trustworthiness_score(
+         prompt, response=expected_unhelpful_response
+     )
+     return float(result.trustworthiness_score)
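Finally, a hedged sketch of how the scoring helpers look to a caller after this change (import path assumed; both return floats in [0.0, 1.0] per their docstrings):

from cleanlab_codex.response_validation import (
    score_unhelpful_response,
    score_untrustworthy_response,
)
from cleanlab_codex.types.tlm import TLMConfig

# The TLM client is now built internally via TLM.from_config(...); callers only pass a config.
unhelpfulness = score_unhelpful_response(
    "Sorry, I cannot help with that.",
    "How do I reset my password?",
    TLMConfig(),
    codex_access_key=None,
)
trustworthiness = score_untrustworthy_response(
    response="Go to Settings > Security > Reset Password.",
    context="Passwords can be reset from the Settings > Security page.",
    query="How do I reset my password?",
    tlm_config=TLMConfig(),
)
print(f"unhelpfulness={unhelpfulness:.2f}, trustworthiness={trustworthiness:.2f}")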