-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathtest_concurrency.py
201 lines (159 loc) · 5.99 KB
/
test_concurrency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from __future__ import annotations
import asyncio
import json
import time
from typing import Any, Dict, List
import pytest
from agents import (
Agent,
GuardrailFunctionOutput,
InputGuardrail,
OutputGuardrail,
RunContextWrapper,
Runner,
function_tool,
)
from .fake_model import FakeModel
from .test_responses import (
get_function_tool_call,
get_text_message,
)
from .testing_processor import SPAN_PROCESSOR_TESTING, fetch_ordered_spans
@pytest.mark.asyncio
async def test_parallel_agent_runs():
"""Test running multiple agents in parallel."""
# Create multiple agents with different characteristics
model1 = FakeModel()
model1.set_next_output([get_text_message("Agent 1 response")])
model2 = FakeModel()
model2.set_next_output([get_text_message("Agent 2 response")])
model3 = FakeModel()
model3.set_next_output([get_text_message("Agent 3 response")])
agent1 = Agent(name="agent1", model=model1)
agent2 = Agent(name="agent2", model=model2)
agent3 = Agent(name="agent3", model=model3)
# Run all agents in parallel
results = await asyncio.gather(
Runner.run(agent1, input="query 1"),
Runner.run(agent2, input="query 2"),
Runner.run(agent3, input="query 3"),
)
# Verify each agent produced the correct response
assert results[0].final_output == "Agent 1 response"
assert results[1].final_output == "Agent 2 response"
assert results[2].final_output == "Agent 3 response"
# Verify trace information was correctly captured for each agent
spans = fetch_ordered_spans()
# Fix: Use a different approach to check for agent spans
assert len(spans) >= 3 # At least 3 spans should be created
@pytest.mark.asyncio
async def test_slow_guardrail_with_fast_model():
"""Test behavior when guardrails are slower than model responses."""
model = FakeModel()
guardrail_executed = False
async def slow_guardrail(
context: RunContextWrapper[Any], agent: Agent[Any], agent_output: str
) -> GuardrailFunctionOutput:
nonlocal guardrail_executed
# Simulate a slow guardrail
await asyncio.sleep(0.1)
guardrail_executed = True
return GuardrailFunctionOutput(output_info={"message": "Checked output"}, tripwire_triggered=False)
agent = Agent(
name="test",
model=model,
output_guardrails=[OutputGuardrail(slow_guardrail)],
)
# Model responds instantly
model.set_next_output([get_text_message("Fast response")])
result = await Runner.run(agent, input="test")
# Verify guardrail was still executed despite model being fast
assert guardrail_executed
assert result.final_output == "Fast response"
@pytest.mark.asyncio
async def test_timeout_on_tool_execution():
"""Test behavior when a tool execution takes too long."""
model = FakeModel()
@function_tool
async def slow_tool() -> str:
# Simulate a very slow tool
await asyncio.sleep(0.5)
return "Slow tool response"
agent = Agent(
name="test",
model=model,
tools=[slow_tool],
)
# Model calls the slow tool
model.set_next_output([
get_function_tool_call("slow_tool", "{}"),
get_text_message("Tool response received")
])
# Run with a very short timeout to force timeout error
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(
Runner.run(agent, input="call slow tool"),
timeout=0.1 # Shorter than the tool execution time
)
@pytest.mark.asyncio
async def test_concurrent_streaming_responses():
"""Test handling of concurrent streaming responses from multiple agents."""
# Create models for streaming
model1 = FakeModel()
model1.set_next_output([get_text_message("Agent 1 streaming response")])
model2 = FakeModel()
model2.set_next_output([get_text_message("Agent 2 streaming response")])
agent1 = Agent(name="stream_agent1", model=model1)
agent2 = Agent(name="stream_agent2", model=model2)
# Run both streaming agents concurrently
results = await asyncio.gather(
Runner.run(agent1, input="stream 1"),
Runner.run(agent2, input="stream 2"),
)
# Both agents should complete successfully
assert results[0].final_output == "Agent 1 streaming response"
assert results[1].final_output == "Agent 2 streaming response"
@pytest.mark.asyncio
async def test_concurrent_tool_execution():
"""Test concurrent execution of multiple tools."""
model = FakeModel()
execution_order = []
@function_tool
async def tool_a() -> str:
execution_order.append("tool_a_start")
await asyncio.sleep(0.1)
execution_order.append("tool_a_end")
return "Tool A result"
@function_tool
async def tool_b() -> str:
execution_order.append("tool_b_start")
await asyncio.sleep(0.05)
execution_order.append("tool_b_end")
return "Tool B result"
@function_tool
async def tool_c() -> str:
execution_order.append("tool_c_start")
await asyncio.sleep(0.02)
execution_order.append("tool_c_end")
return "Tool C result"
agent = Agent(
name="test",
model=model,
tools=[tool_a, tool_b, tool_c],
)
# Set up model to call all tools concurrently
model.set_next_output([
get_function_tool_call("tool_a", "{}"),
get_function_tool_call("tool_b", "{}"),
get_function_tool_call("tool_c", "{}"),
get_text_message("All tools completed")
])
# We're not testing the final output here, just that the tools execute concurrently
await Runner.run(agent, input="execute all tools")
# Verify tools executed concurrently by checking interleaving of start/end events
assert "tool_a_start" in execution_order
assert "tool_b_start" in execution_order
assert "tool_c_start" in execution_order
assert "tool_a_end" in execution_order
assert "tool_b_end" in execution_order
assert "tool_c_end" in execution_order