Fix potential infinite tool call loop by resetting tool_choice after … #263
Merged

Commits (9, all by mini-peanut):
- fb77c74 Fix potential infinite tool call loop by resetting tool_choice after …
- cde67f2 Rename test file for better clarity
- 9384a0f Revert test file name to match project naming style
- d169d79 Update documentation for tool_choice auto-reset feature
- bbcda75 fix: optimize tool_choice reset logic and fix lint errors
- 8f2f76c docs: Update tool_choice reset documentation to match implementation
- 6ed0bee fix: prevent modifying the original agent's model_settings
- 0c747af refactor: streamline tool_choice reset logic
- 07a4af1 refactor: improve comments for clarity in tool_choice reset logic
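Taken together, the commits implement one idea: once the model has produced a tool call, a tool_choice that forces tool use ("required" or a specific function name) is reset to "auto" on a copy of the settings, so the next turn is not forced to call a tool again. The toy loop below illustrates why the reset is needed; it is a sketch of the mechanism, not the SDK's actual runner code, and the function name and flags are invented for illustration.

```python
# Toy model of the runner's turn loop (illustrative only). A forcing
# tool_choice makes every turn produce a tool call, and
# tool_use_behavior="run_llm_again" feeds every tool result back into another
# model turn, so without the reset the loop never terminates on its own.
def count_turns(tool_choice: str, reset_after_tool_call: bool, max_turns: int = 10) -> int:
    turns = 0
    while turns < max_turns:
        turns += 1
        forced = tool_choice != "auto"  # "required" or a specific function name
        if not forced:
            break  # the model may answer directly, so the run can finish
        if reset_after_tool_call:
            tool_choice = "auto"  # the fix: the next turn is no longer forced
    return turns


assert count_turns("required", reset_after_tool_call=True) == 2   # one tool turn, then done
assert count_turns("required", reset_after_tool_call=False) == 10  # hits the cap: runaway loop
```

Per commit 6ed0bee, the reset produces a new ModelSettings rather than mutating the agent's original settings; the tests below exercise exactly that copy-and-reset behavior.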
New test file added by this PR (303 lines):
from unittest import mock
import asyncio
import json

from agents import Agent, ModelSettings, RunConfig, function_tool, Runner
from agents.models.interface import ModelResponse
from agents.items import Usage
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall


@function_tool
def echo(text: str) -> str:
    """Echo the input text."""
    return text


# Mock model implementation that always calls tools when tool_choice is set
class MockModel:
    def __init__(self, tool_call_counter):
        self.tool_call_counter = tool_call_counter

    async def get_response(self, **kwargs):
        tools = kwargs.get("tools", [])
        model_settings = kwargs.get("model_settings")

        # Increment the counter to track how many times this model is called
        self.tool_call_counter["count"] += 1

        # If we've been called many times, we're likely in an infinite loop
        if self.tool_call_counter["count"] > 5:
            self.tool_call_counter["potential_infinite_loop"] = True

        # Always create a tool call if tool_choice is required or a specific function
        tool_calls = []
        if model_settings and model_settings.tool_choice:
            if model_settings.tool_choice in ["required", "echo"] and tools:
                # Create a mock function call to the first tool
                tool = tools[0]
                tool_calls.append(
                    ResponseFunctionToolCall(
                        id="call_1",
                        name=tool.name,
                        arguments=json.dumps({"text": "This is a test"}),
                        call_id="call_1",
                        type="function_call",
                    )
                )

        return ModelResponse(
            output=tool_calls,
            referenceable_id="123",
            usage=Usage(input_tokens=10, output_tokens=10, total_tokens=20),
        )


def reset_tool_choice_if_forced(settings: ModelSettings) -> ModelSettings:
    """Reset logic under test: after a tool call, a tool_choice that forces
    tool use ("required" or a specific function name) is replaced with "auto"
    on a copy of the settings; "auto" and None are left untouched."""
    if isinstance(settings.tool_choice, str) and settings.tool_choice != "auto":
        return ModelSettings(
            temperature=settings.temperature,
            top_p=settings.top_p,
            frequency_penalty=settings.frequency_penalty,
            presence_penalty=settings.presence_penalty,
            tool_choice="auto",  # Reset to auto
            parallel_tool_calls=settings.parallel_tool_calls,
            truncation=settings.truncation,
            max_tokens=settings.max_tokens,
        )
    return settings


class TestToolChoiceReset:
    async def test_tool_choice_resets_after_call(self):
        """Test that tool_choice is reset to 'auto' after a tool call when set to 'required'."""
        # Create an agent with tool_choice="required"
        agent = Agent(
            name="Test agent",
            tools=[echo],
            model_settings=ModelSettings(tool_choice="required"),
        )

        # Simulate a processed model response containing at least one function
        # call; the reset logic is tested directly here rather than through
        # the full execute_tools_and_side_effects path
        processed_response = mock.MagicMock()
        processed_response.functions = [mock.MagicMock()]
        processed_response.computer_actions = []

        # Create a mock run_config with no model_settings overrides
        run_config = mock.MagicMock()
        run_config.model_settings = None

        # Execute the code under test: once a tool call has been made, reset
        # the agent's settings, and the run_config's settings if present
        if processed_response.functions:
            agent.model_settings = reset_tool_choice_if_forced(agent.model_settings)
            if run_config.model_settings:
                run_config.model_settings = reset_tool_choice_if_forced(run_config.model_settings)

        # Check that tool_choice was reset to "auto"
        assert agent.model_settings.tool_choice == "auto"

    async def test_tool_choice_resets_from_specific_function(self):
        """Test tool_choice reset to 'auto' after a call when set to a specific function name."""
        # Create an agent with tool_choice set to a specific function
        agent = Agent(
            name="Test agent",
            instructions="You are a test agent",
            tools=[echo],
            model="gpt-4-0125-preview",
            model_settings=ModelSettings(tool_choice="echo"),
        )

        # Simulate a processed model response with at least one function call
        processed_response = mock.MagicMock()
        processed_response.functions = [mock.MagicMock()]
        processed_response.computer_actions = []

        # Create a mock run_config with no model_settings overrides
        run_config = mock.MagicMock()
        run_config.model_settings = None

        # Execute the code under test
        if processed_response.functions:
            agent.model_settings = reset_tool_choice_if_forced(agent.model_settings)
            if run_config.model_settings:
                run_config.model_settings = reset_tool_choice_if_forced(run_config.model_settings)

        # Check that tool_choice was reset to "auto"
        assert agent.model_settings.tool_choice == "auto"

    async def test_tool_choice_no_reset_when_auto(self):
        """Test that tool_choice is not changed when it's already set to 'auto'."""
        # Create an agent with tool_choice="auto"
        agent = Agent(
            name="Test agent",
            tools=[echo],
            model_settings=ModelSettings(tool_choice="auto"),
        )

        # Simulate a processed model response with at least one function call
        processed_response = mock.MagicMock()
        processed_response.functions = [mock.MagicMock()]
        processed_response.computer_actions = []

        # Create a mock run_config with no model_settings overrides
        run_config = mock.MagicMock()
        run_config.model_settings = None

        # Execute the code under test; "auto" should pass through unchanged
        if processed_response.functions:
            agent.model_settings = reset_tool_choice_if_forced(agent.model_settings)
            if run_config.model_settings:
                run_config.model_settings = reset_tool_choice_if_forced(run_config.model_settings)

        # Check that tool_choice remains "auto"
        assert agent.model_settings.tool_choice == "auto"

    async def test_run_config_tool_choice_reset(self):
        """Test that run_config.model_settings.tool_choice is reset to 'auto'."""
        # Create an agent with no forced tool_choice of its own
        agent = Agent(
            name="Test agent",
            tools=[echo],
            model_settings=ModelSettings(tool_choice=None),
        )

        # Create a run_config with tool_choice="required"
        run_config = RunConfig()
        run_config.model_settings = ModelSettings(tool_choice="required")

        # Simulate a processed model response with at least one function call
        processed_response = mock.MagicMock()
        processed_response.functions = [mock.MagicMock()]
        processed_response.computer_actions = []

        # Execute the code under test
        if processed_response.functions:
            agent.model_settings = reset_tool_choice_if_forced(agent.model_settings)
            if run_config.model_settings:
                run_config.model_settings = reset_tool_choice_if_forced(run_config.model_settings)

        # Check that run_config's tool_choice was reset to "auto"
        assert run_config.model_settings.tool_choice == "auto"

    @mock.patch("agents.run.Runner._get_model")
    async def test_integration_prevents_infinite_loop(self, mock_get_model):
        """Integration test to verify that the tool_choice reset prevents infinite loops."""
        # Create a counter to track model calls and detect potential infinite loops
        tool_call_counter = {"count": 0, "potential_infinite_loop": False}

        # Set up a mock model that always uses tools when tool_choice is set,
        # and return it directly from Runner._get_model
        mock_model_instance = MockModel(tool_call_counter)
        mock_get_model.return_value = mock_model_instance

        # Create an agent with tool_choice="required" to force tool usage
        agent = Agent(
            name="Test agent",
            instructions="You are a test agent",
            tools=[echo],
            model_settings=ModelSettings(tool_choice="required"),
            # "run_llm_again" lets the LLM continue after tool calls; without
            # the tool_choice reset this would loop forever
            tool_use_behavior="run_llm_again",
        )

        # Use a timeout to catch any infinite loop the fix does not address
        try:
            async def run_with_timeout():
                return await Runner.run(agent, input="Test input")

            result = await asyncio.wait_for(run_with_timeout(), timeout=2.0)

            # Verify the agent ran successfully
            assert result is not None

            # Verify the tool was called at least once but not too many times
            # (indicating no infinite loop)
            assert tool_call_counter["count"] >= 1
            assert tool_call_counter["count"] < 5
            assert not tool_call_counter["potential_infinite_loop"]

        except asyncio.TimeoutError:
            # Hitting the timeout means the test failed: likely an infinite loop
            raise AssertionError("Timeout occurred, potential infinite loop detected")
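
For context, the caller-facing behavior the fix enables looks roughly like this. This is a hedged sketch, not part of the PR: it assumes a configured OpenAI API key, and the agent name and prompt are placeholders.

```python
import asyncio

from agents import Agent, ModelSettings, Runner, function_tool


@function_tool
def echo(text: str) -> str:
    """Echo the input text."""
    return text


async def main() -> None:
    # The first turn is forced to call a tool; after that call, tool_choice
    # is reset to "auto", so the run can produce a final answer and stop
    # instead of looping on forced tool calls.
    agent = Agent(
        name="Echo agent",
        tools=[echo],
        model_settings=ModelSettings(tool_choice="required"),
    )
    result = await Runner.run(agent, input="Echo the word hello.")
    print(result.final_output)


asyncio.run(main())
```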