Skip to content

Commit

Permalink
feat: include complete turn response in Agent.create_turn
Browse files Browse the repository at this point in the history
Summary:

In #102, we made a turn's behavior more complete by automatically passing back the tool response and create another turn when client tool is used.

However, this creates a problem with the non-streaming API where the response object only contains information since the last tool call.

This PR is a hacky attemp to address this, by combining the Turn responses into one. I think ideally we should move all the loop logic to only be on the server side, where a turn would pause and the client SDK would pass tool reponses back to resume a turn.

I also changed it to not yield ToolResponseMessage but instead yield a proper ToolExecutionStep event so that it can be treated the same as server side tool execution in terms of logging. I.e. it now outputs:
"tool_execution> Tool:load_url Response:{"content": "\nToday Google announced that they have released the source code to PebbleOS. This is massive for Rebble, and will accelerate our" instead of "CustomTool> {"content": "\nToday Google announced that they have released the source code to PebbleOS. This is massive for Rebble, and will accelerate our efforts to "

Test Plan:

Added test in meta-llama/llama-stack#1078

Run a simple script with Agent and client tool. Observe the returned response has steps from both created turns.

Turn(
│   input_messages=[
│   │   UserMessage(
│   │   │   content='load https://llama-stack.readthedocs.io/en/latest/introduction/index.html and summarize it',
│   │   │   role='user',
│   │   │   context=None
│   │   )
│   ],
│   output_message=CompletionMessage(
│   │   content="The document from the given URL is about Google releasing the source code to PebbleOS, which is a significant development for Rebble. This allows Rebble to accelerate its efforts to produce new hardware. Rebble had been working on its own replacement firmware, RebbleOS, but the release of PebbleOS's source code will help Rebble to build a production-ready real-time OS for the Pebble.",
│   │   role='assistant',
│   │   stop_reason='end_of_turn',
│   │   tool_calls=[]
│   ),
│   session_id='dec1c6c0-ed9b-42c1-97d7-906871acd5ba',
│   started_at=datetime.datetime(2025, 2, 12, 16, 38, 14, 643186),
│   steps=[
│   │   InferenceStep(
│   │   │   api_model_response=CompletionMessage(
│   │   │   │   content='',
│   │   │   │   role='assistant',
│   │   │   │   stop_reason='end_of_turn',
│   │   │   │   tool_calls=[
│   │   │   │   │   ToolCall(
│   │   │   │   │   │   arguments={'url': 'https://llama-stack.readthedocs.io/en/latest/introduction/index.html'},
│   │   │   │   │   │   call_id='5d09151b-8a53-4292-be8d-f21e134d5142',
│   │   │   │   │   │   tool_name='load_url'
│   │   │   │   │   )
│   │   │   │   ]
│   │   │   ),
│   │   │   step_id='d724a238-d02b-4d77-a4bc-a978a54979c6',
│   │   │   step_type='inference',
│   │   │   turn_id='0496c654-cd02-48bb-a2ab-d1a0a5e91aba',
│   │   │   completed_at=datetime.datetime(2025, 2, 12, 16, 38, 15, 523310),
│   │   │   started_at=datetime.datetime(2025, 2, 12, 16, 38, 14, 654535)
│   │   ),
│   │   ToolExecutionStep(
│   │   │   step_id='49f19a5e-6a1e-4b1c-9232-fbafb82f2f89',
│   │   │   step_type='tool_execution',
│   │   │   tool_calls=[
│   │   │   │   ToolCall(
│   │   │   │   │   arguments={'url': 'https://llama-stack.readthedocs.io/en/latest/introduction/index.html'},
│   │   │   │   │   call_id='5d09151b-8a53-4292-be8d-f21e134d5142',
│   │   │   │   │   tool_name='load_url'
│   │   │   │   )
│   │   │   ],
│   │   │   tool_responses=[
│   │   │   │   ToolResponse(
│   │   │   │   │   call_id='5d09151b-8a53-4292-be8d-f21e134d5142',
│   │   │   │   │   content='{"content": "\nToday Google announced that they have released the source code to PebbleOS. This is massive for Rebble, and will accelerate our efforts to produce new hardware.\n\nPreviously, we have been working on our own replacement firmware: RebbleOS. As you can see by the commit history though, progress was slow. Building a production-ready realtime OS for the Pebble is no small feat, and although we were confident we’d get there given enough time, it was never our ideal path. Thanks to the hard work of many people both within Google and not, we finally have our hands on the original source code for PebbleOS. You can read Google’s blog post on this for even more information.\n\nThis does not mean we instantly have the ability to start developing updates for PebbleOS though, we first will need to spend some concentrated time getting it to build. But before we talk about that, let’s talk about Rebble itself.\n"}',
│   │   │   │   │   tool_name='load_url'
│   │   │   │   )
│   │   │   ],
│   │   │   turn_id='0496c654-cd02-48bb-a2ab-d1a0a5e91aba',
│   │   │   completed_at=datetime.datetime(2025, 2, 12, 16, 38, 15, 534830),
│   │   │   started_at=datetime.datetime(2025, 2, 12, 16, 38, 15, 534756)
│   │   ),
│   │   InferenceStep(
│   │   │   api_model_response=CompletionMessage(
│   │   │   │   content="The document from the given URL is about Google releasing the source code to PebbleOS, which is a significant development for Rebble. This allows Rebble to accelerate its efforts to produce new hardware. Rebble had been working on its own replacement firmware, RebbleOS, but the release of PebbleOS's source code will help Rebble to build a production-ready real-time OS for the Pebble.",
│   │   │   │   role='assistant',
│   │   │   │   stop_reason='end_of_turn',
│   │   │   │   tool_calls=[]
│   │   │   ),
│   │   │   step_id='5e6daa91-e689-4d7a-a7f9-d7c3da2eca5a',
│   │   │   step_type='inference',
│   │   │   turn_id='8f65d88d-7643-4dd7-acc7-48cd9e8aa449',
│   │   │   completed_at=datetime.datetime(2025, 2, 12, 16, 38, 16, 179107),
│   │   │   started_at=datetime.datetime(2025, 2, 12, 16, 38, 15, 561449)
│   │   )
│   ],
│   turn_id='0496c654-cd02-48bb-a2ab-d1a0a5e91aba',
│   completed_at=datetime.datetime(2025, 2, 12, 16, 38, 16, 191199),
│   output_attachments=[]
)
```
  • Loading branch information
ehhuang committed Feb 13, 2025
1 parent b5dce10 commit 706a1e6
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 21 deletions.
75 changes: 63 additions & 12 deletions src/llama_stack_client/lib/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,21 @@
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.agents.turn import Turn
from llama_stack_client.types.agents.turn_create_params import Document, Toolgroup
from llama_stack_client.types.agents.turn_create_response import AgentTurnResponseStreamChunk
from llama_stack_client.types.agents.turn_create_response import (
AgentTurnResponseStreamChunk,
)
from llama_stack_client.types.agents.turn_response_event import TurnResponseEvent
from llama_stack_client.types.agents.turn_response_event_payload import (
AgentTurnResponseStepCompletePayload,
)
from llama_stack_client.types.shared.tool_call import ToolCall
from llama_stack_client.types.agents.turn import CompletionMessage
from .client_tool import ClientTool
from .tool_parser import ToolParser
from datetime import datetime
import uuid
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
from llama_stack_client.types.tool_response import ToolResponse

DEFAULT_MAX_ITER = 10

Expand Down Expand Up @@ -119,24 +129,36 @@ def create_turn(
stream: bool = True,
) -> Iterator[AgentTurnResponseStreamChunk] | Turn:
if stream:
return self._create_turn_streaming(messages, session_id, toolgroups, documents, stream)
return self._create_turn_streaming(messages, session_id, toolgroups, documents)
else:
chunk = None
for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents, stream):
chunks = []
for chunk in self._create_turn_streaming(messages, session_id, toolgroups, documents):
if chunk.event.payload.event_type == "turn_complete":
chunks.append(chunk)
pass
if not chunk:
raise Exception("No chunk returned")
if chunk.event.payload.event_type != "turn_complete":
if not chunks:
raise Exception("Turn did not complete")
return chunk.event.payload.turn

# merge chunks
return Turn(
input_messages=chunks[0].event.payload.turn.input_messages,
output_message=chunks[-1].event.payload.turn.output_message,
session_id=chunks[0].event.payload.turn.session_id,
steps=[step for chunk in chunks for step in chunk.event.payload.turn.steps],
turn_id=chunks[0].event.payload.turn.turn_id,
started_at=chunks[0].event.payload.turn.started_at,
completed_at=chunks[-1].event.payload.turn.completed_at,
output_attachments=[
attachment for chunk in chunks for attachment in chunk.event.payload.turn.output_attachments
],
)

def _create_turn_streaming(
self,
messages: List[Union[UserMessage, ToolResponseMessage]],
session_id: Optional[str] = None,
toolgroups: Optional[List[Toolgroup]] = None,
documents: Optional[List[Document]] = None,
stream: bool = True,
) -> Iterator[AgentTurnResponseStreamChunk]:
stop = False
n_iter = 0
Expand All @@ -161,10 +183,39 @@ def _create_turn_streaming(
elif not tool_calls:
yield chunk
else:
next_message = self._run_tool(tool_calls)
yield next_message
tool_execution_start_time = datetime.now()
tool_response_message = self._run_tool(tool_calls)
tool_execution_step = ToolExecutionStep(
step_type="tool_execution",
step_id=str(uuid.uuid4()),
tool_calls=tool_calls,
tool_responses=[
ToolResponse(
tool_name=tool_response_message.tool_name,
content=tool_response_message.content,
call_id=tool_response_message.call_id,
)
],
turn_id=chunk.event.payload.turn.turn_id,
completed_at=datetime.now(),
started_at=tool_execution_start_time,
)
yield AgentTurnResponseStreamChunk(
event=TurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
event_type="step_complete",
step_id=tool_execution_step.step_id,
step_type="tool_execution",
step_details=tool_execution_step,
)
)
)

# HACK: append the tool execution step to the turn
chunk.event.payload.turn.steps.append(tool_execution_step)
yield chunk

# continue the turn when there's a tool call
stop = False
messages = [next_message]
messages = [tool_response_message]
n_iter += 1
10 changes: 1 addition & 9 deletions src/llama_stack_client/lib/agents/event_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from termcolor import cprint

from llama_stack_client.types import InterleavedContent, ToolResponseMessage
from llama_stack_client.types import InterleavedContent


def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str:
Expand Down Expand Up @@ -70,14 +70,6 @@ def _yield_printable_events(self, chunk, previous_event_type=None, previous_step
yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red")
return

if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield TurnStreamPrintableEvent(role="CustomTool", content=chunk.content, color="green")
return

event = chunk.event
event_type = event.payload.event_type

Expand Down

0 comments on commit 706a1e6

Please sign in to comment.