
Commit e13ac3c

update response validation to use TLM via Codex (#58)
1 parent 5517d0a

9 files changed (+294, -120)

CHANGELOG.md (+6, -1)

```diff
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.0.3] - 2025-03-11
+
+- Update response validation methods for Codex as backup to use TLM through Codex API instead of requiring separate TLM API key.
+
 ## [1.0.2] - 2025-03-07
 
 - Extract scores and metadata from detection functions in `response_validation.py`.
@@ -21,7 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Initial release of the `cleanlab-codex` client library.
 
-[Unreleased]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.2...HEAD
+[Unreleased]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.3...HEAD
+[1.0.3]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.2...v1.0.3
 [1.0.2]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.1...v1.0.2
 [1.0.1]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.0...v1.0.1
 [1.0.0]: https://github.com/cleanlab/cleanlab-codex/compare/267a93300f77c94e215d7697223931e7926cad9e...v1.0.0
```

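To illustrate the changelog entry above: a minimal sketch of the new calling convention, assuming the import paths implied by the file layout in this diff. The access key value is a placeholder; `is_bad_response`, `TLMConfig`, and `codex_access_key` all appear in the `response_validation.py` diff below.

```python
# Hedged sketch of the new calling convention introduced by this commit.
# The access key value is a placeholder.
from cleanlab_codex.response_validation import is_bad_response
from cleanlab_codex.types.tlm import TLMConfig

result = is_bad_response(
    response="The capital of France is Berlin.",
    context="France is a country in Europe. Its capital is Paris.",
    query="What is the capital of France?",
    config={"tlm_config": TLMConfig()},  # previously config took a cleanlab_tlm.TLM instance, which needed a separate TLM API key
    codex_access_key="<your-codex-access-key>",  # placeholder; TLM calls now route through the Codex API
)
if result:  # truthy when any validation check detects an issue
    print("Bad response detected:", result)
```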
pyproject.toml (+1, -4)

```diff
@@ -25,9 +25,8 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
 dependencies = [
-  "codex-sdk==0.1.0a9",
+  "codex-sdk==0.1.0a12",
   "pydantic>=2.0.0, <3",
-  "cleanlab-tlm",
 ]
 
 [project.urls]
@@ -44,7 +43,6 @@ extra-dependencies = [
   "pytest",
   "llama-index-core",
   "smolagents; python_version >= '3.10'",
-  "cleanlab-tlm",
   "thefuzz",
   "langchain-core",
 ]
@@ -62,7 +60,6 @@ allow-direct-references = true
 extra-dependencies = [
   "llama-index-core",
   "smolagents; python_version >= '3.10'",
-  "cleanlab-tlm",
   "thefuzz",
   "langchain-core",
 ]
```

src/cleanlab_codex/__about__.py (+1, -1)

```diff
@@ -1,2 +1,2 @@
 # SPDX-License-Identifier: MIT
-__version__ = "1.0.2"
+__version__ = "1.0.3"
```

src/cleanlab_codex/internal/tlm.py (+73, new file)

```diff
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+from cleanlab_codex.internal.sdk_client import (
+    MissingAuthKeyError,
+    client_from_access_key,
+    client_from_api_key,
+)
+from cleanlab_codex.types.tlm import (
+    TLMConfig,
+    TLMOptions,
+    TLMPromptResponse,
+    TLMQualityPreset,
+    TLMScoreResponse,
+)
+
+
+class TLM:
+    def __init__(
+        self,
+        quality_preset: Optional[TLMQualityPreset] = None,
+        *,
+        task: Optional[str] = None,
+        options: Optional[TLMOptions] = None,
+        codex_access_key: Optional[str] = None,
+    ):
+        try:
+            self._sdk_client = client_from_access_key(key=codex_access_key)
+        except MissingAuthKeyError:
+            self._sdk_client = client_from_api_key()
+
+        self._tlm_kwargs: Dict[str, Any] = {}
+        if quality_preset:
+            self._tlm_kwargs["quality_preset"] = quality_preset
+        if task:
+            self._tlm_kwargs["task"] = task
+        if options:
+            self._tlm_kwargs["options"] = options
+
+    @classmethod
+    def from_config(cls, config: TLMConfig, *, codex_access_key: Optional[str] = None) -> TLM:
+        return cls(**config.model_dump(), codex_access_key=codex_access_key)
+
+    def prompt(
+        self,
+        prompt: str,
+        *,
+        constrain_outputs: Optional[List[str]] = None,
+    ) -> TLMPromptResponse:
+        return TLMPromptResponse.model_validate(
+            self._sdk_client.tlm.prompt(
+                prompt=prompt,
+                constrain_outputs=constrain_outputs,
+                **self._tlm_kwargs,
+            ).model_dump()
+        )
+
+    def get_trustworthiness_score(
+        self,
+        prompt: str,
+        response: str,
+        *,
+        constrain_outputs: Optional[List[str]] = None,
+    ) -> TLMScoreResponse:
+        return TLMScoreResponse.model_validate(
+            self._sdk_client.tlm.score(
+                prompt=prompt,
+                response=response,
+                constrain_outputs=constrain_outputs,
+                **self._tlm_kwargs,
+            ).model_dump()
+        )
```

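This new internal wrapper resolves authentication (project access key first, falling back to a Codex API key) and delegates TLM calls to the Codex SDK's `tlm` resource, so callers no longer hold a `cleanlab_tlm.TLM` instance. A minimal usage sketch, using only names from the file above; the access key is a placeholder:

```python
# Hedged usage sketch for the internal wrapper above; the access key is a
# placeholder, and TLMConfig comes from cleanlab_codex.types.tlm.
from cleanlab_codex.internal.tlm import TLM
from cleanlab_codex.types.tlm import TLMConfig

tlm = TLM.from_config(TLMConfig(), codex_access_key="<your-codex-access-key>")

# Score an existing (prompt, response) pair; the SDK payload is re-validated
# into a TLMScoreResponse, which exposes a .trustworthiness_score attribute.
score = tlm.get_trustworthiness_score(
    prompt="What is the capital of France?",
    response="Paris",
)
print(float(score.trustworthiness_score))
```

Note the design choice in `__init__`: the try/except on `MissingAuthKeyError` means an explicit `codex_access_key` wins, while environment-provided API-key auth remains a silent fallback.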
src/cleanlab_codex/internal/utils.py (+3, -2)

```diff
@@ -2,6 +2,7 @@
 
 from typing import TYPE_CHECKING as _TYPE_CHECKING
 
+from pydantic_core import PydanticUndefinedType
 from typing_extensions import get_origin, get_type_hints, is_typeddict
 
 if _TYPE_CHECKING:
@@ -40,7 +41,7 @@ def generate_pydantic_model_docstring(cls: type[BaseModel], name: str) -> str:
     formatted_annotations = "\n ".join(
         format_annotation_from_field_info(field_name, field_info) for field_name, field_info in cls.model_fields.items()
     )
-    formatted_fields = "\n ".join(
+    formatted_fields = "\n".join(
         format_pydantic_field_docstring(field_name, field_info) for field_name, field_info in cls.model_fields.items()
    )
    return f"""
@@ -59,7 +60,7 @@ def format_annotation_from_field_info(field_name: str, field_info: FieldInfo) ->
     annotation = field_name
     if field_info.annotation:
         annotation += f": '{annotation_to_str(field_info.annotation)}'"
-    if field_info.default:
+    if field_info.default and not isinstance(field_info.default, PydanticUndefinedType):
         annotation += f" = {field_info.default}"
     return annotation
```

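The `PydanticUndefinedType` guard matters because in pydantic v2 a required field's `FieldInfo.default` is the `PydanticUndefined` sentinel rather than `None`, and that sentinel is truthy, so the old `if field_info.default:` check rendered a bogus `= PydanticUndefined` default in generated docstrings. A minimal sketch of the bug, assuming a hypothetical model named `Example`:

```python
# Minimal sketch of the bug the new guard fixes. `Example` is a hypothetical
# model; PydanticUndefinedType comes from pydantic_core.
from pydantic import BaseModel
from pydantic_core import PydanticUndefinedType

class Example(BaseModel):
    name: str  # required field: no default is set

field_info = Example.model_fields["name"]
print(field_info.default)  # PydanticUndefined (a truthy sentinel, not None)

# Old check: the truthy sentinel slipped through, rendering "= PydanticUndefined"
if field_info.default:
    print(f"name = {field_info.default}")  # bogus default in the docstring

# New check: the sentinel is filtered out, so no default is rendered
if field_info.default and not isinstance(field_info.default, PydanticUndefinedType):
    print(f"name = {field_info.default}")  # not reached for required fields
```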
src/cleanlab_codex/response_validation.py (+40, -40)

```diff
@@ -15,12 +15,13 @@
 
 from pydantic import BaseModel, ConfigDict, Field
 
+from cleanlab_codex.internal.tlm import TLM
 from cleanlab_codex.internal.utils import generate_pydantic_model_docstring
 from cleanlab_codex.types.response_validation import (
     AggregatedResponseValidationResult,
     SingleResponseValidationResult,
 )
-from cleanlab_codex.types.tlm import TLM
+from cleanlab_codex.types.tlm import TLMConfig
 from cleanlab_codex.utils.errors import MissingDependencyError
 from cleanlab_codex.utils.prompt import default_format_prompt
 
@@ -30,6 +31,7 @@
 _DEFAULT_FALLBACK_SIMILARITY_THRESHOLD: float = 0.7
 _DEFAULT_TRUSTWORTHINESS_THRESHOLD: float = 0.5
 _DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD: float = 0.5
+_DEFAULT_TLM_CONFIG: TLMConfig = TLMConfig()
 
 Query = str
 Context = str
@@ -77,13 +79,12 @@ class BadResponseDetectionConfig(BaseModel):
     )
 
     # Shared config (for untrustworthiness and unhelpfulness checks)
-    tlm: Optional[TLM] = Field(
-        default=None,
-        description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).",
+    tlm_config: TLMConfig = Field(
+        default=_DEFAULT_TLM_CONFIG,
+        description="TLM model configuration to use for untrustworthiness and unhelpfulness checks.",
     )
 
 
-# hack to generate better documentation for help.cleanlab.ai
 BadResponseDetectionConfig.__doc__ = f"""
 {BadResponseDetectionConfig.__doc__}
@@ -99,10 +100,11 @@ def is_bad_response(
     context: Optional[str] = None,
     query: Optional[str] = None,
     config: Union[BadResponseDetectionConfig, Dict[str, Any]] = _DEFAULT_CONFIG,
+    codex_access_key: Optional[str] = None,
 ) -> AggregatedResponseValidationResult:
     """Run a series of checks to determine if a response is bad.
 
-    The function returns a `AggregatedResponseValidationResult` object containing results from multiple validation checks.
+    The function returns an `AggregatedResponseValidationResult` object containing results from multiple validation checks.
     If any check fails (detects an issue), the AggregatedResponseValidationResult will evaluate to `True` when used in a boolean context.
     This means code like `if is_bad_response(...)` will enter the if-block when problems are detected.
@@ -146,28 +148,30 @@
         )
     )
 
-    can_run_untrustworthy_check = query is not None and context is not None and config.tlm is not None
+    can_run_untrustworthy_check = query is not None and context is not None and config.tlm_config is not None
     if can_run_untrustworthy_check:
         # The if condition guarantees these are not None
         validation_checks.append(
             lambda: is_untrustworthy_response(
                 response=response,
                 context=cast(str, context),
                 query=cast(str, query),
-                tlm=cast(TLM, config.tlm),
+                tlm_config=config.tlm_config,
                 trustworthiness_threshold=config.trustworthiness_threshold,
                 format_prompt=config.format_prompt,
+                codex_access_key=codex_access_key,
             )
         )
 
-    can_run_unhelpful_check = query is not None and config.tlm is not None
+    can_run_unhelpful_check = query is not None and config.tlm_config is not None
     if can_run_unhelpful_check:
         validation_checks.append(
             lambda: is_unhelpful_response(
                 response=response,
                 query=cast(str, query),
-                tlm=cast(TLM, config.tlm),
+                tlm_config=config.tlm_config,
                 confidence_score_threshold=config.unhelpfulness_confidence_threshold,
+                codex_access_key=codex_access_key,
             )
         )
 
@@ -238,9 +242,11 @@ def is_untrustworthy_response(
     response: str,
     context: str,
     query: str,
-    tlm: TLM,
+    tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
     trustworthiness_threshold: float = _DEFAULT_TRUSTWORTHINESS_THRESHOLD,
     format_prompt: Callable[[str, str], str] = default_format_prompt,
+    *,
+    codex_access_key: Optional[str] = None,
 ) -> SingleResponseValidationResult:
     """Check if a response is untrustworthy.
 
@@ -252,7 +258,7 @@ def is_untrustworthy_response(
         response (str): The response to check from the assistant.
         context (str): The context information available for answering the query.
         query (str): The user's question or request.
-        tlm (TLM): The TLM model to use for evaluation.
+        tlm_config (TLMConfig): The TLM configuration to use for evaluation.
         trustworthiness_threshold (float): Score threshold (0.0-1.0) under which a response is considered untrustworthy.
             Lower values allow less trustworthy responses. Default 0.5 means responses with scores less than 0.5 are considered untrustworthy.
         format_prompt (Callable[[str, str], str]): Function that takes (query, context) and returns a formatted prompt string.
@@ -266,8 +272,9 @@ def is_untrustworthy_response(
         response=response,
         context=context,
         query=query,
-        tlm=tlm,
+        tlm_config=tlm_config,
         format_prompt=format_prompt,
+        codex_access_key=codex_access_key,
     )
     return SingleResponseValidationResult(
         name="untrustworthy",
@@ -281,8 +288,10 @@ def score_untrustworthy_response(
     response: str,
     context: str,
     query: str,
-    tlm: TLM,
+    tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
     format_prompt: Callable[[str, str], str] = default_format_prompt,
+    *,
+    codex_access_key: Optional[str] = None,
 ) -> float:
     """Scores a response's trustworthiness using [TLM](/tlm), given a context and query.
 
@@ -298,24 +307,20 @@
     Returns:
         float: The score of the response, between 0.0 and 1.0. A lower score indicates the response is less trustworthy.
     """
-    try:
-        from cleanlab_tlm import TLM  # noqa: F401
-    except ImportError as e:
-        raise MissingDependencyError(
-            import_name=e.name or "cleanlab_tlm",
-            package_name="cleanlab-tlm",
-            package_url="https://github.com/cleanlab/cleanlab-tlm",
-        ) from e
     prompt = format_prompt(query, context)
-    result = tlm.get_trustworthiness_score(prompt, response)
-    return float(result["trustworthiness_score"])
+    result = TLM.from_config(tlm_config, codex_access_key=codex_access_key).get_trustworthiness_score(
+        prompt, response=response
+    )
+    return float(result.trustworthiness_score)
 
 
 def is_unhelpful_response(
     response: str,
     query: str,
-    tlm: TLM,
+    tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
     confidence_score_threshold: float = _DEFAULT_UNHELPFULNESS_CONFIDENCE_THRESHOLD,
+    *,
+    codex_access_key: Optional[str] = None,
 ) -> SingleResponseValidationResult:
     """Check if a response is unhelpful by asking [TLM](/tlm) to evaluate it.
@@ -327,14 +332,14 @@
     Args:
         response (str): The response to check.
         query (str): User query that will be used to evaluate if the response is helpful.
-        tlm (TLM): The TLM model to use for evaluation.
+        tlm_config (TLMConfig): The configuration
         confidence_score_threshold (float): Confidence threshold (0.0-1.0) above which a response is considered unhelpful.
             E.g. if confidence_score_threshold is 0.5, then responses with scores higher than 0.5 are considered unhelpful.
 
     Returns:
         SingleResponseValidationResult: The results of the validation check.
     """
-    score: float = score_unhelpful_response(response, query, tlm)
+    score: float = score_unhelpful_response(response, query, tlm_config, codex_access_key=codex_access_key)
 
     # Current implementation of `score_unhelpful_response` produces a score where a higher value means the response if more likely to be unhelpful
     # Changing the TLM prompt used in `score_unhelpful_response` may require restructuring the logic for `fails_check` and potentially adjusting
@@ -350,27 +355,20 @@
 def score_unhelpful_response(
     response: str,
     query: str,
-    tlm: TLM,
+    tlm_config: TLMConfig = _DEFAULT_TLM_CONFIG,
+    *,
+    codex_access_key: Optional[str] = None,
 ) -> float:
     """Scores a response's unhelpfulness using [TLM](/tlm), given a query.
 
     Args:
         response (str): The response to check.
         query (str): User query that will be used to evaluate if the response is helpful.
-        tlm (TLM): The TLM model to use for evaluation.
+        tlm_config (TLMConfig): The TLM model to use for evaluation.
 
     Returns:
         float: The score of the response, between 0.0 and 1.0. A higher score corresponds to a less helpful response.
     """
-    try:
-        from cleanlab_tlm import TLM  # noqa: F401
-    except ImportError as e:
-        raise MissingDependencyError(
-            import_name=e.name or "cleanlab_tlm",
-            package_name="cleanlab-tlm",
-            package_url="https://github.com/cleanlab/cleanlab-tlm",
-        ) from e
-
     # IMPORTANT: The current implementation couples three things that must stay in sync:
     # 1. The question phrasing ("is unhelpful?")
     # 2. The expected_unhelpful_response ("Yes")
@@ -405,5 +403,7 @@
         f"AI Assistant Response: {response}\n\n"
         f"{question}"
     )
-    result = tlm.get_trustworthiness_score(prompt, response=expected_unhelpful_response)
-    return float(result["trustworthiness_score"])
+    result = TLM.from_config(tlm_config, codex_access_key=codex_access_key).get_trustworthiness_score(
+        prompt, response=expected_unhelpful_response
+    )
+    return float(result.trustworthiness_score)
```

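For direct callers of the scoring helpers, a hedged before/after sketch of the migration; import paths follow the diff above, the key value is a placeholder, and the "before" lines reflect the removed `tlm` parameter:

```python
# Hedged before/after sketch of the migration for direct callers of the
# scoring helpers; the access key value is a placeholder.
from cleanlab_codex.response_validation import score_untrustworthy_response
from cleanlab_codex.types.tlm import TLMConfig

# Before (<= 1.0.2): required the optional cleanlab-tlm package plus a TLM API key:
#   tlm = cleanlab_tlm.TLM()
#   score = score_untrustworthy_response(response, context, query, tlm=tlm)

# After (1.0.3): TLM runs through the Codex API; only a Codex credential is needed.
score = score_untrustworthy_response(
    response="Paris is the capital of France.",
    context="France is a country in Europe. Its capital is Paris.",
    query="What is the capital of France?",
    tlm_config=TLMConfig(),  # optional; defaults to _DEFAULT_TLM_CONFIG
    codex_access_key="<your-codex-access-key>",
)
print(score)  # float in [0.0, 1.0]; lower means less trustworthy
```

Note also the result-shape change visible in the diff: the old path indexed a dict (`result["trustworthiness_score"]`), while the new wrapper returns a validated model accessed as `result.trustworthiness_score`.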