From b7e7fdee5585b19a9e059adf7599733359220397 Mon Sep 17 00:00:00 2001 From: Jai0401 <21cs3025@rgipt.ac.in> Date: Wed, 12 Mar 2025 11:45:32 +0530 Subject: [PATCH 1/3] feat: Add strict mode option to function_schema and function_tool --- src/agents/function_schema.py | 6 +++- src/agents/tool.py | 6 ++++ tests/test_function_tool_decorator.py | 44 +++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index a4b57672..981809e3 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -33,7 +33,10 @@ class FuncSchema: """The signature of the function.""" takes_context: bool = False """Whether the function takes a RunContextWrapper argument (must be the first argument).""" - + strict_json_schema: bool = True + """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, + as it increases the likelihood of correct JSON input.""" + def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: """ Converts validated data from the Pydantic model into (args, kwargs), suitable for calling @@ -337,4 +340,5 @@ def function_schema( params_json_schema=json_schema, signature=sig, takes_context=takes_context, + strict_json_schema=strict_json_schema, ) diff --git a/src/agents/tool.py b/src/agents/tool.py index 75872680..f797e221 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -137,6 +137,7 @@ def function_tool( docstring_style: DocstringStyle | None = None, use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, + strict_mode: bool = True, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -150,6 +151,7 @@ def function_tool( docstring_style: DocstringStyle | None = None, use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = None, + strict_mode: bool = True, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -163,6 +165,7 @@ def function_tool( docstring_style: DocstringStyle | None = None, use_docstring_info: bool = True, failure_error_function: ToolErrorFunction | None = default_tool_error_function, + strict_mode: bool = True, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -186,6 +189,7 @@ def function_tool( failure_error_function: If provided, use this function to generate an error message when the tool call fails. The error message is sent to the LLM. If you pass None, then no error message will be sent and instead an Exception will be raised. + strict_mode: If False, allows optional parameters in the function schema. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -195,6 +199,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: description_override=description_override, docstring_style=docstring_style, use_docstring_info=use_docstring_info, + strict_json_schema=strict_mode, ) async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str: @@ -273,6 +278,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: description=schema.description or "", params_json_schema=schema.params_json_schema, on_invoke_tool=_on_invoke_tool, + strict_json_schema=strict_mode, ) # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index 3a47deb4..9e8a4322 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -142,3 +142,47 @@ async def test_no_error_on_invalid_json_async(): tool = will_not_fail_on_bad_json_async result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}") assert result == "error_ModelBehaviorError" + + +@function_tool(strict_mode=False) +def optional_param_function(a: int, b: int | None = None) -> str: + if b is None: + return f"{a}_no_b" + return f"{a}_{b}" + + +@pytest.mark.asyncio +async def test_optional_param_function(): + tool = optional_param_function + + input_data = {"a": 5} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "5_no_b" + + input_data = {"a": 5, "b": 10} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "5_10" + + +@function_tool(strict_mode=False) +def multiple_optional_params_function(x: int = 42, y: str = "hello", z: int | None = None) -> str: + if z is None: + return f"{x}_{y}_no_z" + return f"{x}_{y}_{z}" + + +@pytest.mark.asyncio +async def test_multiple_optional_params_function(): + tool = multiple_optional_params_function + + input_data = {} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "42_hello_no_z" + + input_data = {"x": 10, "y": "world"} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "10_world_no_z" + + input_data = {"x": 10, "y": "world", "z": 99} + output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) + assert output == "10_world_99" \ No newline at end of file From a81da6788d5685e164241dd3f01e2db47f314c52 Mon Sep 17 00:00:00 2001 From: Jaimin Godhani <112328542+Jai0401@users.noreply.github.com> Date: Wed, 12 Mar 2025 14:56:19 +0530 Subject: [PATCH 2/3] Update src/agents/tool.py Co-authored-by: Adrian Cole <64215+codefromthecrypt@users.noreply.github.com> --- src/agents/tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agents/tool.py b/src/agents/tool.py index f797e221..c40f2baf 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -189,7 +189,7 @@ def function_tool( failure_error_function: If provided, use this function to generate an error message when the tool call fails. The error message is sent to the LLM. If you pass None, then no error message will be sent and instead an Exception will be raised. - strict_mode: If False, allows optional parameters in the function schema. + strict_mode: If False, parameters with default values become optional in the function schema. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: From 0c33a24d8fe0cc3bf5a0262ce0f79b9007f971ea Mon Sep 17 00:00:00 2001 From: Jai0401 <21cs3025@rgipt.ac.in> Date: Wed, 12 Mar 2025 15:48:50 +0530 Subject: [PATCH 3/3] fix: resolve linting issues --- src/agents/function_schema.py | 2 +- src/agents/tool.py | 3 ++- tests/test_function_tool_decorator.py | 15 ++++++++++----- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index 981809e3..681affce 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -36,7 +36,7 @@ class FuncSchema: strict_json_schema: bool = True """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input.""" - + def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]: """ Converts validated data from the Pydantic model into (args, kwargs), suitable for calling diff --git a/src/agents/tool.py b/src/agents/tool.py index c40f2baf..cbe87944 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -189,7 +189,8 @@ def function_tool( failure_error_function: If provided, use this function to generate an error message when the tool call fails. The error message is sent to the LLM. If you pass None, then no error message will be sent and instead an Exception will be raised. - strict_mode: If False, parameters with default values become optional in the function schema. + strict_mode: If False, parameters with default values become optional in the + function schema. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index 9e8a4322..b5816606 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -1,6 +1,6 @@ import asyncio import json -from typing import Any +from typing import Any, Optional import pytest @@ -145,7 +145,7 @@ async def test_no_error_on_invalid_json_async(): @function_tool(strict_mode=False) -def optional_param_function(a: int, b: int | None = None) -> str: +def optional_param_function(a: int, b: Optional[int] = None) -> str: if b is None: return f"{a}_no_b" return f"{a}_{b}" @@ -165,17 +165,22 @@ async def test_optional_param_function(): @function_tool(strict_mode=False) -def multiple_optional_params_function(x: int = 42, y: str = "hello", z: int | None = None) -> str: +def multiple_optional_params_function( + x: int = 42, + y: str = "hello", + z: Optional[int] = None, +) -> str: if z is None: return f"{x}_{y}_no_z" return f"{x}_{y}_{z}" + @pytest.mark.asyncio async def test_multiple_optional_params_function(): tool = multiple_optional_params_function - input_data = {} + input_data: dict[str,Any] = {} output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) assert output == "42_hello_no_z" @@ -185,4 +190,4 @@ async def test_multiple_optional_params_function(): input_data = {"x": 10, "y": "world", "z": 99} output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) - assert output == "10_world_99" \ No newline at end of file + assert output == "10_world_99"