Skip to content

Commit b7e7fde

Browse files
committed
feat: Add strict mode option to function_schema and function_tool
1 parent 9a32ff5 commit b7e7fde

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

Diff for: src/agents/function_schema.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@ class FuncSchema:
3333
"""The signature of the function."""
3434
takes_context: bool = False
3535
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
36-
36+
strict_json_schema: bool = True
37+
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
38+
as it increases the likelihood of correct JSON input."""
39+
3740
def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
3841
"""
3942
Converts validated data from the Pydantic model into (args, kwargs), suitable for calling
@@ -337,4 +340,5 @@ def function_schema(
337340
params_json_schema=json_schema,
338341
signature=sig,
339342
takes_context=takes_context,
343+
strict_json_schema=strict_json_schema,
340344
)

Diff for: src/agents/tool.py

+6
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def function_tool(
137137
docstring_style: DocstringStyle | None = None,
138138
use_docstring_info: bool = True,
139139
failure_error_function: ToolErrorFunction | None = None,
140+
strict_mode: bool = True,
140141
) -> FunctionTool:
141142
"""Overload for usage as @function_tool (no parentheses)."""
142143
...
@@ -150,6 +151,7 @@ def function_tool(
150151
docstring_style: DocstringStyle | None = None,
151152
use_docstring_info: bool = True,
152153
failure_error_function: ToolErrorFunction | None = None,
154+
strict_mode: bool = True,
153155
) -> Callable[[ToolFunction[...]], FunctionTool]:
154156
"""Overload for usage as @function_tool(...)."""
155157
...
@@ -163,6 +165,7 @@ def function_tool(
163165
docstring_style: DocstringStyle | None = None,
164166
use_docstring_info: bool = True,
165167
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
168+
strict_mode: bool = True,
166169
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
167170
"""
168171
Decorator to create a FunctionTool from a function. By default, we will:
@@ -186,6 +189,7 @@ def function_tool(
186189
failure_error_function: If provided, use this function to generate an error message when
187190
the tool call fails. The error message is sent to the LLM. If you pass None, then no
188191
error message will be sent and instead an Exception will be raised.
192+
strict_mode: If False, allows optional parameters in the function schema.
189193
"""
190194

191195
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@@ -195,6 +199,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
195199
description_override=description_override,
196200
docstring_style=docstring_style,
197201
use_docstring_info=use_docstring_info,
202+
strict_json_schema=strict_mode,
198203
)
199204

200205
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
@@ -273,6 +278,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
273278
description=schema.description or "",
274279
params_json_schema=schema.params_json_schema,
275280
on_invoke_tool=_on_invoke_tool,
281+
strict_json_schema=strict_mode,
276282
)
277283

278284
# If func is actually a callable, we were used as @function_tool with no parentheses

Diff for: tests/test_function_tool_decorator.py

+44
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,47 @@ async def test_no_error_on_invalid_json_async():
142142
tool = will_not_fail_on_bad_json_async
143143
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
144144
assert result == "error_ModelBehaviorError"
145+
146+
147+
@function_tool(strict_mode=False)
148+
def optional_param_function(a: int, b: int | None = None) -> str:
149+
if b is None:
150+
return f"{a}_no_b"
151+
return f"{a}_{b}"
152+
153+
154+
@pytest.mark.asyncio
155+
async def test_optional_param_function():
156+
tool = optional_param_function
157+
158+
input_data = {"a": 5}
159+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
160+
assert output == "5_no_b"
161+
162+
input_data = {"a": 5, "b": 10}
163+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
164+
assert output == "5_10"
165+
166+
167+
@function_tool(strict_mode=False)
168+
def multiple_optional_params_function(x: int = 42, y: str = "hello", z: int | None = None) -> str:
169+
if z is None:
170+
return f"{x}_{y}_no_z"
171+
return f"{x}_{y}_{z}"
172+
173+
174+
@pytest.mark.asyncio
175+
async def test_multiple_optional_params_function():
176+
tool = multiple_optional_params_function
177+
178+
input_data = {}
179+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
180+
assert output == "42_hello_no_z"
181+
182+
input_data = {"x": 10, "y": "world"}
183+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
184+
assert output == "10_world_no_z"
185+
186+
input_data = {"x": 10, "y": "world", "z": 99}
187+
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
188+
assert output == "10_world_99"

0 commit comments

Comments
 (0)