Skip to content

Commit 0e57bd9

Browse files
committed
Make the reset behavior on tool use configurable
## Summary: #263 added this behavior. The goal was to prevent infinite loops when tool choice was set. The key change I'm making is: 1. Making it configurable on the agent. 2. Doing bookkeeping in the Runner to track this, to prevent mutating agents. 3. Not resetting the global tool choice in RunConfig. ## Test Plan: Unit tests.
1 parent 923a354 commit 0e57bd9

File tree

7 files changed

+173
-98
lines changed

7 files changed

+173
-98
lines changed

Diff for: docs/agents.md

+2-7
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,6 @@ Supplying a list of tools doesn't always mean the LLM will use a tool. You can f
142142

143143
!!! note
144144

145-
To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call in the following scenarios:
146-
147-
1. When `tool_choice` is set to a specific function name (any string that's not "auto", "required", or "none")
148-
2. When `tool_choice` is set to "required" AND there is only one tool available
149-
150-
This targeted reset mechanism allows the model to decide whether to make additional tool calls in subsequent turns while avoiding infinite loops in these specific cases.
151-
145+
To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call. This behavior is configurable via [`agent.reset_tool_choice`][agents.agent.Agent.reset_tool_choice]. The infinite loop is because tool results are sent to the LLM, which then generates another tool call because of `tool_choice`, ad infinitum.
146+
152147
If you want the Agent to completely stop after a tool call (rather than continuing with auto mode), you can set [`Agent.tool_use_behavior="stop_on_first_tool"`] which will directly use the tool output as the final response without further LLM processing.

Diff for: src/agents/_run_impl.py

+34-40
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import dataclasses
55
import inspect
66
from collections.abc import Awaitable
7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
88
from typing import TYPE_CHECKING, Any, cast
99

1010
from openai.types.responses import (
@@ -52,7 +52,7 @@
5252
from .models.interface import ModelTracing
5353
from .run_context import RunContextWrapper, TContext
5454
from .stream_events import RunItemStreamEvent, StreamEvent
55-
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
55+
from .tool import ComputerTool, FunctionTool, FunctionToolResult
5656
from .tracing import (
5757
SpanError,
5858
Trace,
@@ -77,6 +77,22 @@ class QueueCompleteSentinel:
7777
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
7878

7979

80+
@dataclass
81+
class AgentToolUseTracker:
82+
data: list[tuple[Agent, list[str]]] = field(default_factory=list)
83+
84+
def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
85+
existing_data = next((item for item in self.data if item[0] == agent), None)
86+
if existing_data:
87+
existing_data[1].extend(tool_names)
88+
else:
89+
self.data.append((agent, tool_names))
90+
91+
def has_used_tools(self, agent: Agent[Any]) -> bool:
92+
existing_data = next((item for item in self.data if item[0] == agent), None)
93+
return existing_data is not None and len(existing_data[1]) > 0
94+
95+
8096
@dataclass
8197
class ToolRunHandoff:
8298
handoff: Handoff
@@ -101,6 +117,7 @@ class ProcessedResponse:
101117
handoffs: list[ToolRunHandoff]
102118
functions: list[ToolRunFunction]
103119
computer_actions: list[ToolRunComputerAction]
120+
tools_used: list[str] # Names of all tools used, including hosted tools
104121

105122
def has_tools_to_run(self) -> bool:
106123
# Handoffs, functions and computer actions need local processing
@@ -208,29 +225,6 @@ async def execute_tools_and_side_effects(
208225
new_step_items.extend([result.run_item for result in function_results])
209226
new_step_items.extend(computer_results)
210227

211-
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
212-
if processed_response.functions or processed_response.computer_actions:
213-
tools = agent.tools
214-
215-
if (
216-
run_config.model_settings and
217-
cls._should_reset_tool_choice(run_config.model_settings, tools)
218-
):
219-
# update the run_config model settings with a copy
220-
new_run_config_settings = dataclasses.replace(
221-
run_config.model_settings,
222-
tool_choice="auto"
223-
)
224-
run_config = dataclasses.replace(run_config, model_settings=new_run_config_settings)
225-
226-
if cls._should_reset_tool_choice(agent.model_settings, tools):
227-
# Create a modified copy instead of modifying the original agent
228-
new_model_settings = dataclasses.replace(
229-
agent.model_settings,
230-
tool_choice="auto"
231-
)
232-
agent = dataclasses.replace(agent, model_settings=new_model_settings)
233-
234228
# Second, check if there are any handoffs
235229
if run_handoffs := processed_response.handoffs:
236230
return await cls.execute_handoffs(
@@ -322,22 +316,16 @@ async def execute_tools_and_side_effects(
322316
)
323317

324318
@classmethod
325-
def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool:
326-
if model_settings is None or model_settings.tool_choice is None:
327-
return False
319+
def maybe_reset_tool_choice(
320+
cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings
321+
) -> ModelSettings:
322+
"""Resets tool choice to None if the agent has used tools and the agent's reset_tool_choice
323+
flag is True."""
328324

329-
# for specific tool choices
330-
if (
331-
isinstance(model_settings.tool_choice, str) and
332-
model_settings.tool_choice not in ["auto", "required", "none"]
333-
):
334-
return True
325+
if agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent):
326+
return dataclasses.replace(model_settings, tool_choice=None)
335327

336-
# for one tool and required tool choice
337-
if model_settings.tool_choice == "required":
338-
return len(tools) == 1
339-
340-
return False
328+
return model_settings
341329

342330
@classmethod
343331
def process_model_response(
@@ -353,7 +341,7 @@ def process_model_response(
353341
run_handoffs = []
354342
functions = []
355343
computer_actions = []
356-
344+
tools_used: list[str] = []
357345
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
358346
function_map = {tool.name: tool for tool in agent.tools if isinstance(tool, FunctionTool)}
359347
computer_tool = next((tool for tool in agent.tools if isinstance(tool, ComputerTool)), None)
@@ -363,12 +351,15 @@ def process_model_response(
363351
items.append(MessageOutputItem(raw_item=output, agent=agent))
364352
elif isinstance(output, ResponseFileSearchToolCall):
365353
items.append(ToolCallItem(raw_item=output, agent=agent))
354+
tools_used.append("file_search")
366355
elif isinstance(output, ResponseFunctionWebSearch):
367356
items.append(ToolCallItem(raw_item=output, agent=agent))
357+
tools_used.append("web_search")
368358
elif isinstance(output, ResponseReasoningItem):
369359
items.append(ReasoningItem(raw_item=output, agent=agent))
370360
elif isinstance(output, ResponseComputerToolCall):
371361
items.append(ToolCallItem(raw_item=output, agent=agent))
362+
tools_used.append("computer_use")
372363
if not computer_tool:
373364
_error_tracing.attach_error_to_current_span(
374365
SpanError(
@@ -390,6 +381,8 @@ def process_model_response(
390381
if not isinstance(output, ResponseFunctionToolCall):
391382
continue
392383

384+
tools_used.append(output.name)
385+
393386
# Handoffs
394387
if output.name in handoff_map:
395388
items.append(HandoffCallItem(raw_item=output, agent=agent))
@@ -421,6 +414,7 @@ def process_model_response(
421414
handoffs=run_handoffs,
422415
functions=functions,
423416
computer_actions=computer_actions,
417+
tools_used=tools_used,
424418
)
425419

426420
@classmethod

Diff for: src/agents/agent.py

+4
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ class Agent(Generic[TContext]):
143143
web search, etc are always processed by the LLM.
144144
"""
145145

146+
reset_tool_choice: bool = True
147+
"""Whether to reset the tool choice to the default value after a tool has been called. Defaults
148+
to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""
149+
146150
def clone(self, **kwargs: Any) -> Agent[TContext]:
147151
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
148152
```

Diff for: src/agents/models/openai_responses.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,10 @@ async def _fetch_response(
208208
list_input = ItemHelpers.input_to_new_input_list(input)
209209

210210
parallel_tool_calls = (
211-
True if model_settings.parallel_tool_calls and tools and len(tools) > 0
212-
else False if model_settings.parallel_tool_calls is False
211+
True
212+
if model_settings.parallel_tool_calls and tools and len(tools) > 0
213+
else False
214+
if model_settings.parallel_tool_calls is False
213215
else NOT_GIVEN
214216
)
215217

Diff for: src/agents/run.py

+21
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from openai.types.responses import ResponseCompletedEvent
99

1010
from ._run_impl import (
11+
AgentToolUseTracker,
1112
NextStepFinalOutput,
1213
NextStepHandoff,
1314
NextStepRunAgain,
@@ -149,6 +150,8 @@ async def run(
149150
if run_config is None:
150151
run_config = RunConfig()
151152

153+
tool_use_tracker = AgentToolUseTracker()
154+
152155
with TraceCtxManager(
153156
workflow_name=run_config.workflow_name,
154157
trace_id=run_config.trace_id,
@@ -223,6 +226,7 @@ async def run(
223226
context_wrapper=context_wrapper,
224227
run_config=run_config,
225228
should_run_agent_start_hooks=should_run_agent_start_hooks,
229+
tool_use_tracker=tool_use_tracker,
226230
),
227231
)
228232
else:
@@ -234,6 +238,7 @@ async def run(
234238
context_wrapper=context_wrapper,
235239
run_config=run_config,
236240
should_run_agent_start_hooks=should_run_agent_start_hooks,
241+
tool_use_tracker=tool_use_tracker,
237242
)
238243
should_run_agent_start_hooks = False
239244

@@ -481,6 +486,7 @@ async def _run_streamed_impl(
481486
current_agent = starting_agent
482487
current_turn = 0
483488
should_run_agent_start_hooks = True
489+
tool_use_tracker = AgentToolUseTracker()
484490

485491
streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent))
486492

@@ -541,6 +547,7 @@ async def _run_streamed_impl(
541547
context_wrapper,
542548
run_config,
543549
should_run_agent_start_hooks,
550+
tool_use_tracker,
544551
)
545552
should_run_agent_start_hooks = False
546553

@@ -608,6 +615,7 @@ async def _run_single_turn_streamed(
608615
context_wrapper: RunContextWrapper[TContext],
609616
run_config: RunConfig,
610617
should_run_agent_start_hooks: bool,
618+
tool_use_tracker: AgentToolUseTracker,
611619
) -> SingleStepResult:
612620
if should_run_agent_start_hooks:
613621
await asyncio.gather(
@@ -630,6 +638,8 @@ async def _run_single_turn_streamed(
630638

631639
model = cls._get_model(agent, run_config)
632640
model_settings = agent.model_settings.resolve(run_config.model_settings)
641+
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
642+
633643
final_response: ModelResponse | None = None
634644

635645
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
@@ -681,6 +691,7 @@ async def _run_single_turn_streamed(
681691
hooks=hooks,
682692
context_wrapper=context_wrapper,
683693
run_config=run_config,
694+
tool_use_tracker=tool_use_tracker,
684695
)
685696

686697
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
@@ -697,6 +708,7 @@ async def _run_single_turn(
697708
context_wrapper: RunContextWrapper[TContext],
698709
run_config: RunConfig,
699710
should_run_agent_start_hooks: bool,
711+
tool_use_tracker: AgentToolUseTracker,
700712
) -> SingleStepResult:
701713
# Ensure we run the hooks before anything else
702714
if should_run_agent_start_hooks:
@@ -724,6 +736,7 @@ async def _run_single_turn(
724736
handoffs,
725737
context_wrapper,
726738
run_config,
739+
tool_use_tracker,
727740
)
728741

729742
return await cls._get_single_step_result_from_response(
@@ -736,6 +749,7 @@ async def _run_single_turn(
736749
hooks=hooks,
737750
context_wrapper=context_wrapper,
738751
run_config=run_config,
752+
tool_use_tracker=tool_use_tracker,
739753
)
740754

741755
@classmethod
@@ -751,13 +765,17 @@ async def _get_single_step_result_from_response(
751765
hooks: RunHooks[TContext],
752766
context_wrapper: RunContextWrapper[TContext],
753767
run_config: RunConfig,
768+
tool_use_tracker: AgentToolUseTracker,
754769
) -> SingleStepResult:
755770
processed_response = RunImpl.process_model_response(
756771
agent=agent,
757772
response=new_response,
758773
output_schema=output_schema,
759774
handoffs=handoffs,
760775
)
776+
777+
tool_use_tracker.add_tool_use(agent, processed_response.tools_used)
778+
761779
return await RunImpl.execute_tools_and_side_effects(
762780
agent=agent,
763781
original_input=original_input,
@@ -856,9 +874,12 @@ async def _get_new_response(
856874
handoffs: list[Handoff],
857875
context_wrapper: RunContextWrapper[TContext],
858876
run_config: RunConfig,
877+
tool_use_tracker: AgentToolUseTracker,
859878
) -> ModelResponse:
860879
model = cls._get_model(agent, run_config)
861880
model_settings = agent.model_settings.resolve(run_config.model_settings)
881+
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
882+
862883
new_response = await model.get_response(
863884
system_instructions=system_prompt,
864885
input=input,

Diff for: tests/fake_model.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import AsyncIterator
4+
from typing import Any
45

56
from openai.types.responses import Response, ResponseCompletedEvent
67

@@ -31,6 +32,7 @@ def __init__(
3132
[initial_output] if initial_output else []
3233
)
3334
self.tracing_enabled = tracing_enabled
35+
self.last_turn_args: dict[str, Any] = {}
3436

3537
def set_next_output(self, output: list[TResponseOutputItem] | Exception):
3638
self.turn_outputs.append(output)
@@ -53,6 +55,14 @@ async def get_response(
5355
handoffs: list[Handoff],
5456
tracing: ModelTracing,
5557
) -> ModelResponse:
58+
self.last_turn_args = {
59+
"system_instructions": system_instructions,
60+
"input": input,
61+
"model_settings": model_settings,
62+
"tools": tools,
63+
"output_schema": output_schema,
64+
}
65+
5666
with generation_span(disabled=not self.tracing_enabled) as span:
5767
output = self.get_next_output()
5868

0 commit comments

Comments (0)