Skip to content

feat: Add strict_mode option to function_schema and function_tool #60

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 16, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,9 @@ class FuncSchema:
"""The signature of the function."""
takes_context: bool = False
"""Whether the function takes a RunContextWrapper argument (must be the first argument)."""
strict_json_schema: bool = True
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""

def to_call_args(self, data: BaseModel) -> tuple[list[Any], dict[str, Any]]:
"""
@@ -337,4 +340,5 @@ def function_schema(
params_json_schema=json_schema,
signature=sig,
takes_context=takes_context,
strict_json_schema=strict_json_schema,
)
7 changes: 7 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
@@ -137,6 +137,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
@@ -150,6 +151,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
@@ -163,6 +165,7 @@ def function_tool(
docstring_style: DocstringStyle | None = None,
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
@@ -186,6 +189,8 @@ def function_tool(
failure_error_function: If provided, use this function to generate an error message when
the tool call fails. The error message is sent to the LLM. If you pass None, then no
error message will be sent and instead an Exception will be raised.
strict_mode: If False, parameters with default values become optional in the
function schema.
"""

def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@@ -195,6 +200,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
description_override=description_override,
docstring_style=docstring_style,
use_docstring_info=use_docstring_info,
strict_json_schema=strict_mode,
)

async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
@@ -273,6 +279,7 @@ async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
description=schema.description or "",
params_json_schema=schema.params_json_schema,
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
)

# If func is actually a callable, we were used as @function_tool with no parentheses
51 changes: 50 additions & 1 deletion tests/test_function_tool_decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import json
from typing import Any
from typing import Any, Optional

import pytest

@@ -142,3 +142,52 @@ async def test_no_error_on_invalid_json_async():
tool = will_not_fail_on_bad_json_async
result = await tool.on_invoke_tool(ctx_wrapper(), "{not valid json}")
assert result == "error_ModelBehaviorError"


@function_tool(strict_mode=False)
def optional_param_function(a: int, b: Optional[int] = None) -> str:
if b is None:
return f"{a}_no_b"
return f"{a}_{b}"


@pytest.mark.asyncio
async def test_optional_param_function():
tool = optional_param_function

input_data = {"a": 5}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_no_b"

input_data = {"a": 5, "b": 10}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "5_10"


@function_tool(strict_mode=False)
def multiple_optional_params_function(
x: int = 42,
y: str = "hello",
z: Optional[int] = None,
) -> str:
if z is None:
return f"{x}_{y}_no_z"
return f"{x}_{y}_{z}"



@pytest.mark.asyncio
async def test_multiple_optional_params_function():
tool = multiple_optional_params_function

input_data: dict[str,Any] = {}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "42_hello_no_z"

input_data = {"x": 10, "y": "world"}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_no_z"

input_data = {"x": 10, "y": "world", "z": 99}
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
assert output == "10_world_99"