Skip to content

Commit

Permalink
test: add test for Agent.create_turn non-streaming response (#1078)
Browse files Browse the repository at this point in the history
Summary:

This tests the fix to the SDK in
meta-llama/llama-stack-client-python#141

Test Plan:

LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/
--safety-shield meta-llama/Llama-Guard-3-8B
  • Loading branch information
ehhuang authored Feb 14, 2025
1 parent 32d1e50 commit 225dd38
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions tests/client-sdk/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_custom_tool(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "-100" in logs_str
assert "CustomTool" in logs_str
assert "get_boiling_point" in logs_str


# TODO: fix this flaky test
Expand Down Expand Up @@ -403,7 +403,7 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config):
logs_str = "".join(logs)
print(logs_str)
assert "-100" in logs_str
assert "CustomTool" in logs_str
assert "get_boiling_point" in logs_str


def test_rag_agent(llama_stack_client, agent_config):
Expand Down Expand Up @@ -527,3 +527,33 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert f"Tool:{tool_name}" in logs_str


def test_create_turn_response(llama_stack_client, agent_config):
client_tool = TestClientTool()
agent_config = {
**agent_config,
"input_shields": [],
"output_shields": [],
"client_tools": [client_tool.get_tool_definition()],
}

agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,))
session_id = agent.create_session(f"test-session-{uuid4()}")

response = agent.create_turn(
messages=[
{
"role": "user",
"content": "What is the boiling point of polyjuice?",
},
],
session_id=session_id,
stream=False,
)
steps = response.steps
assert len(steps) == 3
assert steps[0].step_type == "inference"
assert steps[1].step_type == "tool_execution"
assert steps[1].tool_calls[0].tool_name == "get_boiling_point"
assert steps[2].step_type == "inference"

0 comments on commit 225dd38

Please sign in to comment.