Skip to content

Commit 951193b

Browse files
authored
feat: Add strict_mode option to function_schema and function_tool (#60)
This PR introduces a `strict_mode: bool = True` option to `@function_tool`, allowing optional parameters when set to False. This change enables more flexibility while maintaining strict JSON schema validation by default. resolves #43 ## Changes: - Added `strict_mode` parameter to `@function_tool` and passed it to `function_schema` and `FunctionTool`. - Updated `function_schema.py` to respect `strict_mode` and allow optional parameters when set to False. - Added unit tests to verify optional parameters work correctly, including multiple optional params with different types. ## Tests: - Verified function calls with missing optional parameters behave as expected. - Added async tests to validate behavior under different configurations.
2 parents cdbf6b0 + 0c33a24 commit 951193b

File tree

3 files changed

+61
-1
lines changed

3 files changed

+61
-1
lines changed

src/agents/function_schema.py

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ class FuncSchema:
3333
"""The signature of the function."""
3434
takes_context: bool = False
3535
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
36+
strict_json_schema: bool = True
37+
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
38+
as it increases the likelihood of correct JSON input."""
3639

3740
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
3841
"""
@@ -337,4 +340,5 @@ def function_schema(
337340
params_json_schema=json_schema,
338341
signature=sig,
339342
takes_context=takes_context,
343+
strict_json_schema=strict_json_schema,
340344
)

src/agents/tool.py

+7
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def function_tool(
137137
docstring_style: DocstringStyle | None = None,
138138
use_docstring_info: bool = True,
139139
failure_error_function: ToolErrorFunction | None = None,
140+
strict_mode: bool = True,
140141
) -> FunctionTool:
141142
"""Overload for usage as @function_tool (no parentheses)."""
142143
...
@@ -150,6 +151,7 @@ def function_tool(
150151
docstring_style: DocstringStyle | None = None,
151152
use_docstring_info: bool = True,
152153
failure_error_function: ToolErrorFunction | None = None,
154+
strict_mode: bool = True,
153155
) -> Callable[[ToolFunction[...]], FunctionTool]:
154156
"""Overload for usage as @function_tool(...)."""
155157
...
@@ -163,6 +165,7 @@ def function_tool(
163165
docstring_style: DocstringStyle | None = None,
164166
use_docstring_info: bool = True,
165167
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
168+
strict_mode: bool = True,
166169
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
167170
"""
168171
Decorator to create a FunctionTool from a function. By default, we will:
@@ -186,6 +189,8 @@ def function_tool(
186189
failure_error_function: If provided, use this function to generate an error message when
187190
the tool call fails. The error message is sent to the LLM. If you pass None, then no
188191
error message will be sent and instead an Exception will be raised.
192+
strict_mode: If False, parameters with default values become optional in the
193+
function schema.
189194
"""
190195

191196
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@@ -195,6 +200,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
195200
description_override=description_override,
196201
docstring_style=docstring_style,
197202
use_docstring_info=use_docstring_info,
203+
strict_json_schema=strict_mode,
198204
)
199205

200206
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
@@ -273,6 +279,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
273279
description=schema.description or "",
274280
params_json_schema=schema.params_json_schema,
275281
on_invoke_tool=_on_invoke_tool,
282+
strict_json_schema=strict_mode,
276283
)
277284

278285
# If func is actually a callable, we were used as @function_tool with no parentheses

tests/test_function_tool_decorator.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import json
3-
from typing import Any
3+
from typing import Any, Optional
44

55
import pytest
66

@@ -142,3 +142,52 @@ async def test_no_error_on_invalid_json_async():
142142
tool = will_not_fail_on_bad_json_async
143143
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
144144
assert result == "error_ModelBehaviorError"
145+
146+
147+
@function_tool(strict_mode=False)
148+
def optional_param_function(a: int, b: Optional[int] = None) -> str:
149+
if b is None:
150+
return f"{a}_no_b"
151+
return f"{a}_{b}"
152+
153+
154+
@pytest.mark.asyncio
155+
async def test_optional_param_function():
156+
tool = optional_param_function
157+
158+
input_data = {"a": 5}
159+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
160+
assert output == "5_no_b"
161+
162+
input_data = {"a": 5, "b": 10}
163+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
164+
assert output == "5_10"
165+
166+
167+
@function_tool(strict_mode=False)
168+
def multiple_optional_params_function(
169+
x: int = 42,
170+
y: str = "hello",
171+
z: Optional[int] = None,
172+
) -> str:
173+
if z is None:
174+
return f"{x}_{y}_no_z"
175+
return f"{x}_{y}_{z}"
176+
177+
178+
179+
@pytest.mark.asyncio
180+
async def test_multiple_optional_params_function():
181+
tool = multiple_optional_params_function
182+
183+
input_data: dict[str,Any] = {}
184+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
185+
assert output == "42_hello_no_z"
186+
187+
input_data = {"x": 10, "y": "world"}
188+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
189+
assert output == "10_world_no_z"
190+
191+
input_data = {"x": 10, "y": "world", "z": 99}
192+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
193+
assert output == "10_world_99"

0 commit comments

Comments
 (0)