Skip to content

Commit 225dd38

Browse files
authored
test: add test for Agent.create_turn non-streaming response (#1078)
Summary: This tests the fix to the SDK in meta-llama/llama-stack-client-python#141 Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B
1 parent 32d1e50 commit 225dd38

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

tests/client-sdk/agents/test_agents.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def test_custom_tool(llama_stack_client, agent_config):
319319
logs = [str(log) for log in EventLogger().log(response) if log is not None]
320320
logs_str = "".join(logs)
321321
assert "-100" in logs_str
322-
assert "CustomTool" in logs_str
322+
assert "get_boiling_point" in logs_str
323323

324324

325325
# TODO: fix this flaky test
@@ -403,7 +403,7 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
403403
logs_str = "".join(logs)
404404
print(logs_str)
405405
assert "-100" in logs_str
406-
assert "CustomTool" in logs_str
406+
assert "get_boiling_point" in logs_str
407407

408408

409409
def test_rag_agent(llama_stack_client, agent_config):
@@ -527,3 +527,33 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
527527
logs = [str(log) for log in EventLogger().log(response) if log is not None]
528528
logs_str = "".join(logs)
529529
assert f"Tool:{tool_name}" in logs_str
530+
531+
532+
def test_create_turn_response(llama_stack_client, agent_config):
533+
client_tool = TestClientTool()
534+
agent_config = {
535+
**agent_config,
536+
"input_shields": [],
537+
"output_shields": [],
538+
"client_tools": [client_tool.get_tool_definition()],
539+
}
540+
541+
agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
542+
session_id = agent.create_session(f"test-session-{uuid4()}")
543+
544+
response = agent.create_turn(
545+
messages=[
546+
{
547+
"role": "user",
548+
"content": "What is the boiling point of polyjuice?",
549+
},
550+
],
551+
session_id=session_id,
552+
stream=False,
553+
)
554+
steps = response.steps
555+
assert len(steps) == 3
556+
assert steps[0].step_type == "inference"
557+
assert steps[1].step_type == "tool_execution"
558+
assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
559+
assert steps[2].step_type == "inference"

0 commit comments

Comments
 (0)