Skip to content

Commit 8dfd6ff

Browse files
authored
Added support for passing tool_call_id via the RunContextWrapper (#766)
This PR fixes issue: #559 By adding the tool_call_id to the RunContextWrapper prior to calling tools. This gives the ability to access the tool_call_id in the implementation of the tool.
1 parent dcb88e6 commit 8dfd6ff

File tree

8 files changed

+115
-35
lines changed

8 files changed

+115
-35
lines changed

src/agents/_run_impl.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
MCPToolApprovalRequest,
7676
Tool,
7777
)
78+
from .tool_context import ToolContext
7879
from .tracing import (
7980
SpanError,
8081
Trace,
@@ -543,23 +544,24 @@ async def run_single_tool(
543544
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
544545
) -> Any:
545546
with function_span(func_tool.name) as span_fn:
547+
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
546548
if config.trace_include_sensitive_data:
547549
span_fn.span_data.input = tool_call.arguments
548550
try:
549551
_, _, result = await asyncio.gather(
550-
hooks.on_tool_start(context_wrapper, agent, func_tool),
552+
hooks.on_tool_start(tool_context, agent, func_tool),
551553
(
552-
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
554+
agent.hooks.on_tool_start(tool_context, agent, func_tool)
553555
if agent.hooks
554556
else _coro.noop_coroutine()
555557
),
556-
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
558+
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
557559
)
558560

559561
await asyncio.gather(
560-
hooks.on_tool_end(context_wrapper, agent, func_tool, result),
562+
hooks.on_tool_end(tool_context, agent, func_tool, result),
561563
(
562-
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
564+
agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
563565
if agent.hooks
564566
else _coro.noop_coroutine()
565567
),

src/agents/function_schema.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .exceptions import UserError
1414
from .run_context import RunContextWrapper
1515
from .strict_schema import ensure_strict_json_schema
16+
from .tool_context import ToolContext
1617

1718

1819
@dataclass
@@ -237,21 +238,21 @@ def function_schema(
237238
ann = type_hints.get(first_name, first_param.annotation)
238239
if ann != inspect._empty:
239240
origin = get_origin(ann) or ann
240-
if origin is RunContextWrapper:
241+
if origin is RunContextWrapper or origin is ToolContext:
241242
takes_context = True # Mark that the function takes context
242243
else:
243244
filtered_params.append((first_name, first_param))
244245
else:
245246
filtered_params.append((first_name, first_param))
246247

247-
# For parameters other than the first, raise error if any use RunContextWrapper.
248+
# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
248249
for name, param in params[1:]:
249250
ann = type_hints.get(name, param.annotation)
250251
if ann != inspect._empty:
251252
origin = get_origin(ann) or ann
252-
if origin is RunContextWrapper:
253+
if origin is RunContextWrapper or origin is ToolContext:
253254
raise UserError(
254-
f"RunContextWrapper param found at non-first position in function"
255+
f"RunContextWrapper/ToolContext param found at non-first position in function"
255256
f" {func.__name__}"
256257
)
257258
filtered_params.append((name, param))

src/agents/tool.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .items import RunItem
2121
from .logger import logger
2222
from .run_context import RunContextWrapper
23+
from .tool_context import ToolContext
2324
from .tracing import SpanError
2425
from .util import _error_tracing
2526
from .util._types import MaybeAwaitable
@@ -31,8 +32,13 @@
3132

3233
ToolFunctionWithoutContext = Callable[ToolParams, Any]
3334
ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]
35+
ToolFunctionWithToolContext = Callable[Concatenate[ToolContext, ToolParams], Any]
3436

35-
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
37+
ToolFunction = Union[
38+
ToolFunctionWithoutContext[ToolParams],
39+
ToolFunctionWithContext[ToolParams],
40+
ToolFunctionWithToolContext[ToolParams],
41+
]
3642

3743

3844
@dataclass
@@ -62,7 +68,7 @@ class FunctionTool:
6268
params_json_schema: dict[str, Any]
6369
"""The JSON schema for the tool's parameters."""
6470

65-
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
71+
on_invoke_tool: Callable[[ToolContext[Any], str], Awaitable[Any]]
6672
"""A function that invokes the tool with the given context and parameters. The params passed
6773
are:
6874
1. The tool run context.
@@ -344,7 +350,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
344350
strict_json_schema=strict_mode,
345351
)
346352

347-
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
353+
async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
348354
try:
349355
json_data: dict[str, Any] = json.loads(input) if input else {}
350356
except Exception as e:
@@ -393,7 +399,7 @@ async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
393399

394400
return result
395401

396-
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
402+
async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
397403
try:
398404
return await _on_invoke_tool_impl(ctx, input)
399405
except Exception as e:

src/agents/tool_context.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from dataclasses import dataclass, field, fields
2+
from typing import Any
3+
4+
from .run_context import RunContextWrapper, TContext
5+
6+
7+
def _assert_must_pass_tool_call_id() -> str:
8+
raise ValueError("tool_call_id must be passed to ToolContext")
9+
10+
@dataclass
11+
class ToolContext(RunContextWrapper[TContext]):
12+
"""The context of a tool call."""
13+
14+
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
15+
"""The ID of the tool call."""
16+
17+
@classmethod
18+
def from_agent_context(
19+
cls, context: RunContextWrapper[TContext], tool_call_id: str
20+
) -> "ToolContext":
21+
"""
22+
Create a ToolContext from a RunContextWrapper.
23+
"""
24+
# Grab the names of the RunContextWrapper's init=True fields
25+
base_values: dict[str, Any] = {
26+
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
27+
}
28+
return cls(tool_call_id=tool_call_id, **base_values)

tests/test_function_tool.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
99
from agents.tool import default_tool_error_function
10+
from agents.tool_context import ToolContext
1011

1112

1213
def argless_function() -> str:
@@ -18,11 +19,11 @@ async def test_argless_function():
1819
tool = function_tool(argless_function)
1920
assert tool.name == "argless_function"
2021

21-
result = await tool.on_invoke_tool(RunContextWrapper(None), "")
22+
result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
2223
assert result == "ok"
2324

2425

25-
def argless_with_context(ctx: RunContextWrapper[str]) -> str:
26+
def argless_with_context(ctx: ToolContext[str]) -> str:
2627
return "ok"
2728

2829

@@ -31,11 +32,11 @@ async def test_argless_with_context():
3132
tool = function_tool(argless_with_context)
3233
assert tool.name == "argless_with_context"
3334

34-
result = await tool.on_invoke_tool(RunContextWrapper(None), "")
35+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
3536
assert result == "ok"
3637

3738
# Extra JSON should not raise an error
38-
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
39+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
3940
assert result == "ok"
4041

4142

@@ -48,15 +49,15 @@ async def test_simple_function():
4849
tool = function_tool(simple_function, failure_error_function=None)
4950
assert tool.name == "simple_function"
5051

51-
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
52+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
5253
assert result == 6
5354

54-
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
55+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
5556
assert result == 3
5657

5758
# Missing required argument should raise an error
5859
with pytest.raises(ModelBehaviorError):
59-
await tool.on_invoke_tool(RunContextWrapper(None), "")
60+
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
6061

6162

6263
class Foo(BaseModel):
@@ -84,7 +85,7 @@ async def test_complex_args_function():
8485
"bar": Bar(x="hello", y=10),
8586
}
8687
)
87-
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
88+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
8889
assert result == "6 hello10 hello"
8990

9091
valid_json = json.dumps(
@@ -93,7 +94,7 @@ async def test_complex_args_function():
9394
"bar": Bar(x="hello", y=10),
9495
}
9596
)
96-
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
97+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
9798
assert result == "3 hello10 hello"
9899

99100
valid_json = json.dumps(
@@ -103,12 +104,12 @@ async def test_complex_args_function():
103104
"baz": "world",
104105
}
105106
)
106-
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
107+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
107108
assert result == "3 hello10 world"
108109

109110
# Missing required argument should raise an error
110111
with pytest.raises(ModelBehaviorError):
111-
await tool.on_invoke_tool(RunContextWrapper(None), '{"foo": {"a": 1}}')
112+
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
112113

113114

114115
def test_function_config_overrides():
@@ -168,7 +169,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
168169
assert tool.params_json_schema[key] == value
169170
assert tool.strict_json_schema
170171

171-
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"data": "hello"}')
172+
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
172173
assert result == "hello_done"
173174

174175
tool_not_strict = FunctionTool(
@@ -183,7 +184,7 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
183184
assert "additionalProperties" not in tool_not_strict.params_json_schema
184185

185186
result = await tool_not_strict.on_invoke_tool(
186-
RunContextWrapper(None), '{"data": "hello", "bar": "baz"}'
187+
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
187188
)
188189
assert result == "hello_done"
189190

@@ -194,7 +195,7 @@ def my_func(a: int, b: int = 5):
194195
raise ValueError("test")
195196

196197
tool = function_tool(my_func)
197-
ctx = RunContextWrapper(None)
198+
ctx = ToolContext(None, tool_call_id="1")
198199

199200
result = await tool.on_invoke_tool(ctx, "")
200201
assert "Invalid JSON" in str(result)
@@ -218,7 +219,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
218219
return f"error_{error.__class__.__name__}"
219220

220221
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
221-
ctx = RunContextWrapper(None)
222+
ctx = ToolContext(None, tool_call_id="1")
222223

223224
result = await tool.on_invoke_tool(ctx, "")
224225
assert result == "error_ModelBehaviorError"
@@ -242,7 +243,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
242243
return f"error_{error.__class__.__name__}"
243244

244245
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
245-
ctx = RunContextWrapper(None)
246+
ctx = ToolContext(None, tool_call_id="1")
246247

247248
result = await tool.on_invoke_tool(ctx, "")
248249
assert result == "error_ModelBehaviorError"

tests/test_function_tool_decorator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@
77

88
from agents import function_tool
99
from agents.run_context import RunContextWrapper
10+
from agents.tool_context import ToolContext
1011

1112

1213
class DummyContext:
1314
def __init__(self):
1415
self.data = "something"
1516

1617

17-
def ctx_wrapper() -> RunContextWrapper[DummyContext]:
18-
return RunContextWrapper(DummyContext())
18+
def ctx_wrapper() -> ToolContext[DummyContext]:
19+
return ToolContext(context=DummyContext(), tool_call_id="1")
1920

2021

2122
@function_tool
@@ -44,7 +45,7 @@ async def test_sync_no_context_with_args_invocation():
4445

4546

4647
@function_tool
47-
def sync_with_context(ctx: RunContextWrapper[DummyContext], name: str) -> str:
48+
def sync_with_context(ctx: ToolContext[DummyContext], name: str) -> str:
4849
return f"{name}_{ctx.context.data}"
4950

5051

@@ -71,7 +72,7 @@ async def test_async_no_context_invocation():
7172

7273

7374
@function_tool
74-
async def async_with_context(ctx: RunContextWrapper[DummyContext], prefix: str, num: int) -> str:
75+
async def async_with_context(ctx: ToolContext[DummyContext], prefix: str, num: int) -> str:
7576
await asyncio.sleep(0)
7677
return f"{prefix}-{num}-{ctx.context.data}"
7778

tests/test_responses.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,12 @@ def _foo() -> str:
4949
)
5050

5151

52-
def get_function_tool_call(name: str, arguments: str | None = None) -> ResponseOutputItem:
52+
def get_function_tool_call(
53+
name: str, arguments: str | None = None, call_id: str | None = None
54+
) -> ResponseOutputItem:
5355
return ResponseFunctionToolCall(
5456
id="1",
55-
call_id="2",
57+
call_id=call_id or "2",
5658
type="function_call",
5759
name=name,
5860
arguments=arguments or "",

tests/test_run_step_execution.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import json
34
from typing import Any
45

56
import pytest
@@ -26,6 +27,8 @@
2627
RunImpl,
2728
SingleStepResult,
2829
)
30+
from agents.tool import function_tool
31+
from agents.tool_context import ToolContext
2932

3033
from .test_responses import (
3134
get_final_output_message,
@@ -158,6 +161,42 @@ async def test_multiple_tool_calls():
158161
assert isinstance(result.next_step, NextStepRunAgain)
159162

160163

164+
@pytest.mark.asyncio
165+
async def test_multiple_tool_calls_with_tool_context():
166+
async def _fake_tool(context: ToolContext[str], value: str) -> str:
167+
return f"{value}-{context.tool_call_id}"
168+
169+
tool = function_tool(_fake_tool, name_override="fake_tool", failure_error_function=None)
170+
171+
agent = Agent(
172+
name="test",
173+
tools=[tool],
174+
)
175+
response = ModelResponse(
176+
output=[
177+
get_function_tool_call("fake_tool", json.dumps({"value": "123"}), call_id="1"),
178+
get_function_tool_call("fake_tool", json.dumps({"value": "456"}), call_id="2"),
179+
],
180+
usage=Usage(),
181+
response_id=None,
182+
)
183+
184+
result = await get_execute_result(agent, response)
185+
assert result.original_input == "hello"
186+
187+
# 4 items: new message, 2 tool calls, 2 tool call outputs
188+
assert len(result.generated_items) == 4
189+
assert isinstance(result.next_step, NextStepRunAgain)
190+
191+
items = result.generated_items
192+
assert_item_is_function_tool_call(items[0], "fake_tool", json.dumps({"value": "123"}))
193+
assert_item_is_function_tool_call(items[1], "fake_tool", json.dumps({"value": "456"}))
194+
assert_item_is_function_tool_call_output(items[2], "123-1")
195+
assert_item_is_function_tool_call_output(items[3], "456-2")
196+
197+
assert isinstance(result.next_step, NextStepRunAgain)
198+
199+
161200
@pytest.mark.asyncio
162201
async def test_handoff_output_leads_to_handoff_next_step():
163202
agent_1 = Agent(name="test_1")

0 commit comments

Comments
 (0)