Commit ad020b7

Make the reset behavior on tool use configurable (#335)

## Summary:
#263 added this behavior. The goal was to prevent infinite loops when tool choice was set. The key changes I'm making are:
1. Making it configurable on the agent.
2. Doing bookkeeping in the Runner to track this, to prevent mutating agents.
3. Not resetting the global tool choice in RunConfig.

## Test Plan:
Unit tests.
2 parents 362a9dc + 6fb5792 commit ad020b7
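As a quick illustration of point 1 above, the flag lives on the agent itself. A minimal sketch, assuming the public `agents` exports (`Agent`, `function_tool`); the weather tool is made up for illustration:

```python
from agents import Agent, function_tool


@function_tool
def get_weather(city: str) -> str:
    """Illustrative tool; any function tool works here."""
    return f"It is always sunny in {city}."


agent = Agent(
    name="Weather agent",
    instructions="Answer weather questions using the tool.",
    tools=[get_weather],
    # New in this commit: defaults to True. Set it to False to keep your
    # configured tool_choice even after a tool call (at the risk of a loop).
    reset_tool_choice=False,
)
```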

File tree

7 files changed: +173 -97 lines


docs/agents.md (+2 -7)
@@ -142,11 +142,6 @@ Supplying a list of tools doesn't always mean the LLM will use a tool. You can f
 
 !!! note
 
-    To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call in the following scenarios:
-
-    1. When `tool_choice` is set to a specific function name (any string that's not "auto", "required", or "none")
-    2. When `tool_choice` is set to "required" AND there is only one tool available
-
-    This targeted reset mechanism allows the model to decide whether to make additional tool calls in subsequent turns while avoiding infinite loops in these specific cases.
-
+    To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call. This behavior is configurable via [`agent.reset_tool_choice`][agents.agent.Agent.reset_tool_choice]. The infinite loop is because tool results are sent to the LLM, which then generates another tool call because of `tool_choice`, ad infinitum.
+
 If you want the Agent to completely stop after a tool call (rather than continuing with auto mode), you can set [`Agent.tool_use_behavior="stop_on_first_tool"`] which will directly use the tool output as the final response without further LLM processing.
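The updated note can be seen in action with a forced tool choice; a minimal end-to-end sketch of the default behavior, assuming the `agents` exports used elsewhere in the docs (`Agent`, `ModelSettings`, `Runner`, `function_tool`) and a made-up tool:

```python
import asyncio

from agents import Agent, ModelSettings, Runner, function_tool


@function_tool
def get_weather(city: str) -> str:
    """Illustrative tool."""
    return f"It is always sunny in {city}."


agent = Agent(
    name="Weather agent",
    instructions="Use the tool, then summarize the result.",
    tools=[get_weather],
    # Force the first turn to call the tool; because reset_tool_choice
    # defaults to True, the framework then clears the forced choice so the
    # model can produce a normal final answer instead of looping.
    model_settings=ModelSettings(tool_choice="get_weather"),
)


async def main() -> None:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)


asyncio.run(main())
```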

src/agents/_run_impl.py (+34 -39)
@@ -4,7 +4,7 @@
 import dataclasses
 import inspect
 from collections.abc import Awaitable
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, cast
 
 from openai.types.responses import (
@@ -77,6 +77,23 @@ class QueueCompleteSentinel:
 _NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
 
 
+@dataclass
+class AgentToolUseTracker:
+    agent_to_tools: list[tuple[Agent, list[str]]] = field(default_factory=list)
+    """Tuple of (agent, list of tools used). Can't use a dict because agents aren't hashable."""
+
+    def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
+        existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
+        if existing_data:
+            existing_data[1].extend(tool_names)
+        else:
+            self.agent_to_tools.append((agent, tool_names))
+
+    def has_used_tools(self, agent: Agent[Any]) -> bool:
+        existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
+        return existing_data is not None and len(existing_data[1]) > 0
+
+
 @dataclass
 class ToolRunHandoff:
     handoff: Handoff
@@ -101,6 +118,7 @@ class ProcessedResponse:
     handoffs: list[ToolRunHandoff]
     functions: list[ToolRunFunction]
     computer_actions: list[ToolRunComputerAction]
+    tools_used: list[str]  # Names of all tools used, including hosted tools
 
     def has_tools_to_run(self) -> bool:
         # Handoffs, functions and computer actions need local processing
@@ -208,29 +226,6 @@ async def execute_tools_and_side_effects(
         new_step_items.extend([result.run_item for result in function_results])
         new_step_items.extend(computer_results)
 
-        # Reset tool_choice to "auto" after tool execution to prevent infinite loops
-        if processed_response.functions or processed_response.computer_actions:
-            tools = agent.tools
-
-            if (
-                run_config.model_settings and
-                cls._should_reset_tool_choice(run_config.model_settings, tools)
-            ):
-                # update the run_config model settings with a copy
-                new_run_config_settings = dataclasses.replace(
-                    run_config.model_settings,
-                    tool_choice="auto"
-                )
-                run_config = dataclasses.replace(run_config, model_settings=new_run_config_settings)
-
-            if cls._should_reset_tool_choice(agent.model_settings, tools):
-                # Create a modified copy instead of modifying the original agent
-                new_model_settings = dataclasses.replace(
-                    agent.model_settings,
-                    tool_choice="auto"
-                )
-                agent = dataclasses.replace(agent, model_settings=new_model_settings)
-
         # Second, check if there are any handoffs
         if run_handoffs := processed_response.handoffs:
             return await cls.execute_handoffs(
@@ -322,22 +317,16 @@ async def execute_tools_and_side_effects(
         )
 
     @classmethod
-    def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool:
-        if model_settings is None or model_settings.tool_choice is None:
-            return False
+    def maybe_reset_tool_choice(
+        cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings
+    ) -> ModelSettings:
+        """Resets tool choice to None if the agent has used tools and the agent's reset_tool_choice
+        flag is True."""
 
-        # for specific tool choices
-        if (
-            isinstance(model_settings.tool_choice, str) and
-            model_settings.tool_choice not in ["auto", "required", "none"]
-        ):
-            return True
+        if agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent):
+            return dataclasses.replace(model_settings, tool_choice=None)
 
-        # for one tool and required tool choice
-        if model_settings.tool_choice == "required":
-            return len(tools) == 1
-
-        return False
+        return model_settings
 
     @classmethod
     def process_model_response(
@@ -354,7 +343,7 @@ def process_model_response(
         run_handoffs = []
         functions = []
         computer_actions = []
-
+        tools_used: list[str] = []
         handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
         function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
         computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
@@ -364,12 +353,15 @@ def process_model_response(
                 items.append(MessageOutputItem(raw_item=output, agent=agent))
             elif isinstance(output, ResponseFileSearchToolCall):
                 items.append(ToolCallItem(raw_item=output, agent=agent))
+                tools_used.append("file_search")
             elif isinstance(output, ResponseFunctionWebSearch):
                 items.append(ToolCallItem(raw_item=output, agent=agent))
+                tools_used.append("web_search")
             elif isinstance(output, ResponseReasoningItem):
                 items.append(ReasoningItem(raw_item=output, agent=agent))
             elif isinstance(output, ResponseComputerToolCall):
                 items.append(ToolCallItem(raw_item=output, agent=agent))
+                tools_used.append("computer_use")
                 if not computer_tool:
                     _error_tracing.attach_error_to_current_span(
                         SpanError(
@@ -391,6 +383,8 @@ def process_model_response(
             if not isinstance(output, ResponseFunctionToolCall):
                 continue
 
+            tools_used.append(output.name)
+
             # Handoffs
             if output.name in handoff_map:
                 items.append(HandoffCallItem(raw_item=output, agent=agent))
@@ -422,6 +416,7 @@
             handoffs=run_handoffs,
             functions=functions,
             computer_actions=computer_actions,
+            tools_used=tools_used,
         )
 
     @classmethod
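To make the new bookkeeping concrete, here is a small sketch of how `AgentToolUseTracker` and `maybe_reset_tool_choice` interact across turns. It imports from the internal `agents._run_impl` module shown above, so treat it as illustration rather than a supported API:

```python
from agents import Agent, ModelSettings
from agents._run_impl import AgentToolUseTracker, RunImpl

agent = Agent(name="Example", model_settings=ModelSettings(tool_choice="required"))
tracker = AgentToolUseTracker()

# Turn 1: nothing recorded for this agent yet, so the forced choice survives.
settings = RunImpl.maybe_reset_tool_choice(agent, tracker, agent.model_settings)
assert settings.tool_choice == "required"

# The Runner records tool calls per agent instead of mutating the agent.
tracker.add_tool_use(agent, ["get_weather"])

# Turn 2: the resolved settings now come back with tool_choice cleared.
settings = RunImpl.maybe_reset_tool_choice(agent, tracker, agent.model_settings)
assert settings.tool_choice is None
```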

src/agents/agent.py (+4)
@@ -155,6 +155,10 @@ class Agent(Generic[TContext]):
     web search, etc are always processed by the LLM.
     """
 
+    reset_tool_choice: bool = True
+    """Whether to reset the tool choice to the default value after a tool has been called. Defaults
+    to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""
+
     def clone(self, **kwargs: Any) -> Agent[TContext]:
         """Make a copy of the agent, with the given arguments changed. For example, you could do:
         ```
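Since `reset_tool_choice` is a regular dataclass field, it can also be overridden on a copy via `clone()` (which uses `dataclasses.replace` under the hood); a brief sketch with made-up agent names:

```python
from agents import Agent

base_agent = Agent(name="Assistant", instructions="Be helpful.")

# The original agent keeps the default; only the clone opts out of the reset.
forced_variant = base_agent.clone(name="Forced tools", reset_tool_choice=False)

assert base_agent.reset_tool_choice is True
assert forced_variant.reset_tool_choice is False
```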

src/agents/models/openai_responses.py (+4 -2)
@@ -208,8 +208,10 @@ async def _fetch_response(
         list_input = ItemHelpers.input_to_new_input_list(input)
 
         parallel_tool_calls = (
-            True if model_settings.parallel_tool_calls and tools and len(tools) > 0
-            else False if model_settings.parallel_tool_calls is False
+            True
+            if model_settings.parallel_tool_calls and tools and len(tools) > 0
+            else False
+            if model_settings.parallel_tool_calls is False
             else NOT_GIVEN
         )
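The hunk above only reformats the nested conditional expression; its three-way behavior may be easier to read as a standalone helper. A sketch assuming the OpenAI SDK's `NOT_GIVEN` sentinel (the helper name is made up):

```python
from __future__ import annotations

from openai import NOT_GIVEN


def resolve_parallel_tool_calls(parallel_tool_calls: bool | None, tools: list) -> object:
    """Mirror of the expression: True only when requested and tools exist,
    False when explicitly disabled, otherwise omit the field entirely."""
    if parallel_tool_calls and tools and len(tools) > 0:
        return True
    if parallel_tool_calls is False:
        return False
    return NOT_GIVEN
```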

src/agents/run.py (+21)
@@ -10,6 +10,7 @@
 from agents.tool import Tool
 
 from ._run_impl import (
+    AgentToolUseTracker,
     NextStepFinalOutput,
     NextStepHandoff,
     NextStepRunAgain,
@@ -151,6 +152,8 @@ async def run(
         if run_config is None:
             run_config = RunConfig()
 
+        tool_use_tracker = AgentToolUseTracker()
+
         with TraceCtxManager(
             workflow_name=run_config.workflow_name,
             trace_id=run_config.trace_id,
@@ -227,6 +230,7 @@
                                 context_wrapper=context_wrapper,
                                 run_config=run_config,
                                 should_run_agent_start_hooks=should_run_agent_start_hooks,
+                                tool_use_tracker=tool_use_tracker,
                             ),
                         )
                     else:
@@ -239,6 +243,7 @@
                            context_wrapper=context_wrapper,
                            run_config=run_config,
                            should_run_agent_start_hooks=should_run_agent_start_hooks,
+                            tool_use_tracker=tool_use_tracker,
                        )
                    should_run_agent_start_hooks = False
 
@@ -486,6 +491,7 @@ async def _run_streamed_impl(
         current_agent = starting_agent
         current_turn = 0
         should_run_agent_start_hooks = True
+        tool_use_tracker = AgentToolUseTracker()
 
         streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent))
 
@@ -546,6 +552,7 @@
                         context_wrapper,
                         run_config,
                         should_run_agent_start_hooks,
+                        tool_use_tracker,
                     )
                     should_run_agent_start_hooks = False
 
@@ -613,6 +620,7 @@ async def _run_single_turn_streamed(
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
         should_run_agent_start_hooks: bool,
+        tool_use_tracker: AgentToolUseTracker,
     ) -> SingleStepResult:
         if should_run_agent_start_hooks:
             await asyncio.gather(
@@ -635,6 +643,8 @@
         all_tools = await cls._get_all_tools(agent)
         model = cls._get_model(agent, run_config)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
+        model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
+
         final_response: ModelResponse | None = None
 
         input = ItemHelpers.input_to_new_input_list(streamed_result.input)
@@ -687,6 +697,7 @@
             hooks=hooks,
             context_wrapper=context_wrapper,
             run_config=run_config,
+            tool_use_tracker=tool_use_tracker,
         )
 
         RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
@@ -704,6 +715,7 @@ async def _run_single_turn(
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
         should_run_agent_start_hooks: bool,
+        tool_use_tracker: AgentToolUseTracker,
     ) -> SingleStepResult:
         # Ensure we run the hooks before anything else
         if should_run_agent_start_hooks:
@@ -732,6 +744,7 @@
             handoffs,
             context_wrapper,
             run_config,
+            tool_use_tracker,
         )
 
         return await cls._get_single_step_result_from_response(
@@ -745,6 +758,7 @@
             hooks=hooks,
             context_wrapper=context_wrapper,
             run_config=run_config,
+            tool_use_tracker=tool_use_tracker,
         )
 
     @classmethod
@@ -761,6 +775,7 @@ async def _get_single_step_result_from_response(
         hooks: RunHooks[TContext],
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
+        tool_use_tracker: AgentToolUseTracker,
     ) -> SingleStepResult:
         processed_response = RunImpl.process_model_response(
             agent=agent,
@@ -769,6 +784,9 @@
             output_schema=output_schema,
             handoffs=handoffs,
         )
+
+        tool_use_tracker.add_tool_use(agent, processed_response.tools_used)
+
         return await RunImpl.execute_tools_and_side_effects(
             agent=agent,
             original_input=original_input,
@@ -868,9 +886,12 @@ async def _get_new_response(
         handoffs: list[Handoff],
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
+        tool_use_tracker: AgentToolUseTracker,
     ) -> ModelResponse:
         model = cls._get_model(agent, run_config)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
+        model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
+
         new_response = await model.get_response(
             system_instructions=system_prompt,
             input=input,
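Because the tracker created at the top of each run is keyed by agent, a forced `tool_choice` on one agent is not cleared just because a different agent used a tool earlier in the same run (for example, across a handoff). A small sketch, again leaning on the internal `agents._run_impl` module purely for illustration:

```python
from agents import Agent, ModelSettings
from agents._run_impl import AgentToolUseTracker, RunImpl

triage = Agent(name="Triage", model_settings=ModelSettings(tool_choice="required"))
billing = Agent(name="Billing", model_settings=ModelSettings(tool_choice="required"))

tracker = AgentToolUseTracker()
tracker.add_tool_use(triage, ["lookup_account"])

# Only the agent that actually used tools has its forced choice cleared.
assert RunImpl.maybe_reset_tool_choice(triage, tracker, triage.model_settings).tool_choice is None
assert RunImpl.maybe_reset_tool_choice(billing, tracker, billing.model_settings).tool_choice == "required"
```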

tests/fake_model.py (+10)
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import AsyncIterator
+from typing import Any
 
 from openai.types.responses import Response, ResponseCompletedEvent
 
@@ -31,6 +32,7 @@ def __init__(
             [initial_output] if initial_output else []
         )
         self.tracing_enabled = tracing_enabled
+        self.last_turn_args: dict[str, Any] = {}
 
     def set_next_output(self, output: list[TResponseOutputItem] | Exception):
         self.turn_outputs.append(output)
@@ -53,6 +55,14 @@ async def get_response(
         handoffs: list[Handoff],
         tracing: ModelTracing,
     ) -> ModelResponse:
+        self.last_turn_args = {
+            "system_instructions": system_instructions,
+            "input": input,
+            "model_settings": model_settings,
+            "tools": tools,
+            "output_schema": output_schema,
+        }
+
         with generation_span(disabled=not self.tracing_enabled) as span:
             output = self.get_next_output()
 
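A sketch of how a test might use the new `last_turn_args` capture. It assumes `pytest-asyncio` is configured and builds the fake assistant message inline from the OpenAI response types; the helper below is made up for the example:

```python
import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText

from agents import Agent, ModelSettings, Runner
from tests.fake_model import FakeModel


def text_message(text: str) -> ResponseOutputMessage:
    # Minimal assistant message; returning it as the only output ends the run.
    return ResponseOutputMessage(
        id="msg_1",
        type="message",
        role="assistant",
        status="completed",
        content=[ResponseOutputText(type="output_text", text=text, annotations=[])],
    )


@pytest.mark.asyncio
async def test_model_sees_resolved_tool_choice():
    model = FakeModel()
    agent = Agent(
        name="test",
        model=model,
        model_settings=ModelSettings(tool_choice="auto"),
    )
    model.set_next_output([text_message("done")])

    await Runner.run(agent, input="hello")

    # The fake model recorded the arguments it was called with on the last turn.
    assert model.last_turn_args["model_settings"].tool_choice == "auto"
```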
