Skip to content

Commit fcd4291

Browse files
authored
refactor: auto extract tool properties from fn def (#24)
* auto extract tool properties from fn def * update docstring
1 parent 31c24b9 commit fcd4291

File tree

11 files changed

+151
-119
lines changed

11 files changed

+151
-119
lines changed

src/cleanlab_codex/codex_tool.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,19 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, ClassVar, Optional
5+
from typing import Any, Optional
6+
7+
from typing_extensions import Annotated
68

79
from cleanlab_codex.project import Project
10+
from cleanlab_codex.utils.function import pydantic_model_from_function, required_properties_from_model
811

912

1013
class CodexTool:
1114
"""A tool that connects to a Codex project to answer questions."""
1215

1316
_tool_name = "ask_advisor"
14-
_tool_description = "Asks an all-knowing advisor this query in cases where it cannot be answered from the provided Context. If the answer is available, this returns None."
15-
_tool_properties: ClassVar[dict[str, Any]] = {
16-
"question": {
17-
"type": "string",
18-
"description": "The question to ask the advisor. This should be the same as the original user question, except in cases where the user question is missing information that could be additionally clarified.",
19-
}
20-
}
21-
_tool_requirements: ClassVar[list[str]] = ["question"]
17+
_tool_description = "Asks an all-knowing advisor this query in cases where it cannot be answered from the provided Context. If the answer is unavailable, this returns None."
2218
DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question."
2319

2420
def __init__(
@@ -29,6 +25,9 @@ def __init__(
2925
):
3026
self._project = project
3127
self._fallback_answer = fallback_answer
28+
self._tool_function_schema = pydantic_model_from_function(self._tool_name, self.query)
29+
self._tool_properties = self._tool_function_schema.model_json_schema()["properties"]
30+
self._tool_requirements = required_properties_from_model(self._tool_function_schema)
3231

3332
@classmethod
3433
def from_access_key(
@@ -86,11 +85,17 @@ def fallback_answer(self, value: Optional[str]) -> None:
8685
"""Sets the fallback answer to use if the Codex project cannot answer the question."""
8786
self._fallback_answer = value
8887

89-
def query(self, question: str) -> Optional[str]:
90-
"""Asks an all-knowing advisor this question in cases where it cannot be answered from the provided Context.
88+
def query(
89+
self,
90+
question: Annotated[
91+
str,
92+
"The question to ask the advisor. This should be the same as the original user question, except in cases where the user question is missing information that could be additionally clarified.",
93+
],
94+
) -> Optional[str]:
95+
"""Asks an all-knowing advisor this question in cases where it cannot be answered from the provided Context. If the answer is unavailable, this returns a fallback answer or None.
9196
9297
Args:
93-
question: The question to ask the advisor. This should be the same as the original user question, except in cases where the user question is missing information that could be additionally clarified.
98+
question (str): The question to ask the advisor. This should be the same as the original user question, except in cases where the user question is missing information that could be additionally clarified.
9499
95100
Returns:
96101
The answer to the question if available. If no answer is available, this returns a fallback answer or None.
@@ -130,17 +135,11 @@ def to_llamaindex_tool(self) -> Any:
130135
"""
131136
from llama_index.core.tools import FunctionTool
132137

133-
from cleanlab_codex.utils.llamaindex import get_function_schema
134-
135138
return FunctionTool.from_defaults(
136139
fn=self.query,
137140
name=self._tool_name,
138141
description=self._tool_description,
139-
fn_schema=get_function_schema(
140-
name=self._tool_name,
141-
func=self.query,
142-
tool_properties=self._tool_properties,
143-
),
142+
fn_schema=self._tool_function_schema,
144143
)
145144

146145
def to_langchain_tool(self) -> Any:
@@ -150,17 +149,11 @@ def to_langchain_tool(self) -> Any:
150149
"""
151150
from langchain_core.tools.structured import StructuredTool
152151

153-
from cleanlab_codex.utils.langchain import create_args_schema
154-
155152
return StructuredTool.from_function(
156153
func=self.query,
157154
name=self._tool_name,
158155
description=self._tool_description,
159-
args_schema=create_args_schema(
160-
name=self._tool_name,
161-
func=self.query,
162-
tool_properties=self._tool_properties,
163-
),
156+
args_schema=self._tool_function_schema,
164157
)
165158

166159
def to_aws_converse_tool(self) -> Any:

src/cleanlab_codex/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from cleanlab_codex.utils.aws import Tool as AWSConverseTool
22
from cleanlab_codex.utils.aws import ToolSpec as AWSToolSpec
33
from cleanlab_codex.utils.aws import format_as_aws_converse_tool
4+
from cleanlab_codex.utils.function import FunctionParameters
45
from cleanlab_codex.utils.openai import Function as OpenAIFunction
56
from cleanlab_codex.utils.openai import Tool as OpenAITool
67
from cleanlab_codex.utils.openai import format_as_openai_tool
7-
from cleanlab_codex.utils.types import FunctionParameters
88

99
__all__ = [
1010
"FunctionParameters",

src/cleanlab_codex/utils/aws.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pydantic import BaseModel, Field
66

7-
from cleanlab_codex.utils.types import FunctionParameters
7+
from cleanlab_codex.utils.function import FunctionParameters
88

99

1010
class ToolSpec(BaseModel):
@@ -27,6 +27,11 @@ def format_as_aws_converse_tool(
2727
toolSpec=ToolSpec(
2828
name=tool_name,
2929
description=tool_description,
30-
inputSchema={"json": FunctionParameters(properties=tool_properties, required=required_properties)},
30+
inputSchema={
31+
"json": FunctionParameters(
32+
properties=tool_properties,
33+
required=required_properties,
34+
)
35+
},
3136
)
3237
).model_dump(by_alias=True)

src/cleanlab_codex/utils/function.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from inspect import signature
2+
from typing import Any, Callable, Dict, List, Literal, Type
3+
4+
from pydantic import BaseModel, Field, create_model
5+
from typing_extensions import Annotated, get_args, get_origin
6+
7+
8+
# Schema entry for a single tool parameter: a JSON-schema-style type name plus
# a human-readable description. Both fields are required (neither has a default).
class Property(BaseModel):
    # Must be one of the six type names listed in the Literal below.
    type: Literal["string", "number", "integer", "boolean", "array", "object"]
    # Free-text explanation of the parameter.
    description: str
11+
12+
13+
# Parameter schema for a tool/function: an "object" whose named properties are
# `Property` entries, together with the list of parameter names that are required.
class FunctionParameters(BaseModel):
    # Fixed to "object"; defaults to that value so callers need not pass it.
    type: Literal["object"] = "object"
    # Maps each parameter name to its type/description entry.
    properties: Dict[str, Property]
    # Names of required parameters -- presumably a subset of the `properties`
    # keys; nothing in this model enforces that.
    required: List[str]
17+
18+
19+
def pydantic_model_from_function(name: str, func: Callable[..., Any]) -> Type[BaseModel]:
    """
    Create a pydantic model representing a function's schema.

    For example, a function with the following signature:

    ```python
    def my_function(
        a: Annotated[int, "This is an integer"], b: str = "default"
    ) -> None: ...
    ```

    will be represented by the following pydantic model when `name="my_function"`:

    ```python
    class my_function(BaseModel):
        a: int = Field(description="This is an integer")
        b: str = "default"
    ```

    String annotations (for instance those produced when the module defining
    `func` uses `from __future__ import annotations`) are evaluated against the
    globals of the module where `func` was defined, falling back to the names
    available in this module. Parameters without an annotation are typed as `Any`.
    For `Annotated[...]` parameters, the first string found in the metadata is
    used as the field description.

    Args:
        name: The name for the pydantic model.
        func: The function to create a schema for.

    Returns:
        A pydantic model representing the function's schema.

    Raises:
        NameError: If a string annotation refers to a name that is visible neither
            in the function's defining module nor in this module.
    """
    # Names visible where `func` was defined take precedence; this module's own
    # globals stay available so annotations that resolved previously still do.
    # Callables without `__globals__` (e.g. some wrappers) just use the fallback.
    eval_namespace = {**globals(), **getattr(func, "__globals__", {})}

    fields = {}
    for param_name, param in signature(func).parameters.items():
        param_type = param.annotation
        if isinstance(param_type, str):
            # The annotation text is source code written by the function's author,
            # not external input, hence the lint suppression.
            param_type = eval(param_type, eval_namespace)  # noqa: S307

        description = None
        if get_origin(param_type) is Annotated:
            # get_args() yields the underlying type followed by the metadata items.
            param_type, *metadata = get_args(param_type)
            description = next((item for item in metadata if isinstance(item, str)), None)

        if param_type is param.empty:
            param_type = Any

        if param.default is param.empty:
            # No default: omit `default` so pydantic treats the field as required.
            fields[param_name] = (param_type, Field(description=description))
        else:
            fields[param_name] = (
                param_type,
                Field(default=param.default, description=description),
            )

    return create_model(name, **fields)  # type: ignore
75+
76+
77+
def required_properties_from_model(model: Type[BaseModel]) -> List[str]:
    """Collect the names of the model's fields that must be provided.

    Args:
        model: The pydantic model class to inspect.

    Returns:
        Field names for which `is_required()` is true, in the order the
        model declares them.
    """
    required_names = []
    for field_name, field_info in model.model_fields.items():
        if field_info.is_required():
            required_names.append(field_name)
    return required_names

src/cleanlab_codex/utils/langchain.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

src/cleanlab_codex/utils/llamaindex.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/cleanlab_codex/utils/openai.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pydantic import BaseModel
66

7-
from cleanlab_codex.utils.types import FunctionParameters
7+
from cleanlab_codex.utils.function import FunctionParameters
88

99

1010
class Tool(BaseModel):
@@ -28,6 +28,9 @@ def format_as_openai_tool(
2828
function=Function(
2929
name=tool_name,
3030
description=tool_description,
31-
parameters=FunctionParameters(properties=tool_properties, required=required_properties),
31+
parameters=FunctionParameters(
32+
properties=tool_properties,
33+
required=required_properties,
34+
),
3235
)
3336
).model_dump()

src/cleanlab_codex/utils/types.py

Lines changed: 0 additions & 14 deletions
This file was deleted.
File renamed without changes.

tests/utils/__init__.py

Whitespace-only changes.

tests/utils/test_function.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Any
2+
3+
from typing_extensions import Annotated
4+
5+
from cleanlab_codex.utils.function import pydantic_model_from_function
6+
7+
8+
def test_function_schema_with_annotated_params() -> None:
    """An `Annotated[str, "..."]` parameter yields a required `str` field carrying the description."""

    def annotated_fn(
        a: Annotated[str, "This is a string"],  # noqa: ARG001
    ) -> None: ...

    schema_model = pydantic_model_from_function("test_function", annotated_fn)
    json_schema = schema_model.model_json_schema()
    assert json_schema["title"] == "test_function"

    field_a = schema_model.model_fields["a"]
    assert field_a.annotation is str
    assert field_a.description == "This is a string"
    assert field_a.is_required()
18+
19+
20+
def test_function_schema_without_annotations() -> None:
    """An unannotated parameter becomes a required `Any` field with no description."""

    def bare_fn(a) -> None:  # type: ignore # noqa: ARG001
        ...

    schema_model = pydantic_model_from_function("test_function", bare_fn)
    json_schema = schema_model.model_json_schema()
    assert json_schema["title"] == "test_function"

    field_a = schema_model.model_fields["a"]
    assert field_a.annotation is Any  # type: ignore[comparison-overlap]
    assert field_a.is_required()
    assert field_a.description is None
29+
30+
31+
def test_function_schema_with_default_param() -> None:
    """A parameter with a default becomes an optional field that keeps the default value."""

    def defaulted_fn(a: int = 1) -> None:  # noqa: ARG001
        ...

    schema_model = pydantic_model_from_function("test_function", defaulted_fn)
    json_schema = schema_model.model_json_schema()
    assert json_schema["title"] == "test_function"

    field_a = schema_model.model_fields["a"]
    assert field_a.annotation is int
    assert field_a.default == 1
    assert not field_a.is_required()
    assert field_a.description is None

0 commit comments

Comments
 (0)