Skip to content

Commit 7c0bbf6

Browse files
Authored by iamarunbrahma, with co-authors Copilot, rysweet, and ekzhu
feat: add support for list of messages as team task input and update Society of Mind Agent (#4500)
* feat: add support for list of messages as team task input
* Update society of mind agent to use the list input task

---------

Co-authored-by: Copilot <[email protected]>
Co-authored-by: Ryan Sweet <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
1 parent c714515 commit 7c0bbf6

16 files changed

+360
-133
lines changed

Diff for: python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, AsyncGenerator, List, Mapping, Sequence
2+
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args
33

44
from autogen_core import CancellationToken
55

66
from ..base import ChatAgent, Response, TaskResult
7-
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
7+
from ..messages import (
8+
AgentMessage,
9+
ChatMessage,
10+
TextMessage,
11+
)
812
from ..state import BaseState
913

1014

@@ -45,8 +49,9 @@ async def on_messages_stream(
4549
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
4650
) -> AsyncGenerator[AgentMessage | Response, None]:
4751
"""Handles incoming messages and returns a stream of messages and
48-
and the final item is the response. The base implementation in :class:`BaseChatAgent`
49-
simply calls :meth:`on_messages` and yields the messages in the response."""
52+
and the final item is the response. The base implementation in
53+
:class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
54+
the messages in the response."""
5055
response = await self.on_messages(messages, cancellation_token)
5156
for inner_message in response.inner_messages or []:
5257
yield inner_message
@@ -55,7 +60,7 @@ async def on_messages_stream(
5560
async def run(
5661
self,
5762
*,
58-
task: str | ChatMessage | None = None,
63+
task: str | ChatMessage | List[ChatMessage] | None = None,
5964
cancellation_token: CancellationToken | None = None,
6065
) -> TaskResult:
6166
"""Run the agent with the given task and return the result."""
@@ -69,7 +74,14 @@ async def run(
6974
text_msg = TextMessage(content=task, source="user")
7075
input_messages.append(text_msg)
7176
output_messages.append(text_msg)
72-
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
77+
elif isinstance(task, list):
78+
for msg in task:
79+
if isinstance(msg, get_args(ChatMessage)[0]):
80+
input_messages.append(msg)
81+
output_messages.append(msg)
82+
else:
83+
raise ValueError(f"Invalid message type in list: {type(msg)}")
84+
elif isinstance(task, get_args(ChatMessage)[0]):
7385
input_messages.append(task)
7486
output_messages.append(task)
7587
else:
@@ -83,7 +95,7 @@ async def run(
8395
async def run_stream(
8496
self,
8597
*,
86-
task: str | ChatMessage | None = None,
98+
task: str | ChatMessage | List[ChatMessage] | None = None,
8799
cancellation_token: CancellationToken | None = None,
88100
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
89101
"""Run the agent with the given task and return a stream of messages
@@ -99,7 +111,15 @@ async def run_stream(
99111
input_messages.append(text_msg)
100112
output_messages.append(text_msg)
101113
yield text_msg
102-
elif isinstance(task, TextMessage | MultiModalMessage | StopMessage | HandoffMessage):
114+
elif isinstance(task, list):
115+
for msg in task:
116+
if isinstance(msg, get_args(ChatMessage)[0]):
117+
input_messages.append(msg)
118+
output_messages.append(msg)
119+
yield msg
120+
else:
121+
raise ValueError(f"Invalid message type in list: {type(msg)}")
122+
elif isinstance(task, get_args(ChatMessage)[0]):
103123
input_messages.append(task)
104124
output_messages.append(task)
105125
yield task
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import AsyncGenerator, List, Sequence
1+
from typing import Any, AsyncGenerator, List, Mapping, Sequence
22

3-
from autogen_core import CancellationToken, Image
4-
from autogen_core.models import ChatCompletionClient
5-
from autogen_core.models._types import SystemMessage
3+
from autogen_core import CancellationToken
4+
from autogen_core.models import ChatCompletionClient, LLMMessage, SystemMessage, UserMessage
65

76
from autogen_agentchat.base import Response
7+
from autogen_agentchat.state import SocietyOfMindAgentState
88

99
from ..base import TaskResult, Team
1010
from ..messages import (
@@ -32,60 +32,76 @@ class SocietyOfMindAgent(BaseChatAgent):
3232
team (Team): The team of agents to use.
3333
model_client (ChatCompletionClient): The model client to use for preparing responses.
3434
description (str, optional): The description of the agent.
35+
instruction (str, optional): The instruction to use when generating a response using the inner team's messages.
36+
Defaults to :attr:`DEFAULT_INSTRUCTION`. It assumes the role of 'system'.
37+
response_prompt (str, optional): The response prompt to use when generating a response using the inner team's messages.
38+
Defaults to :attr:`DEFAULT_RESPONSE_PROMPT`. It assumes the role of 'system'.
3539
3640
3741
Example:
3842
3943
.. code-block:: python
4044
4145
import asyncio
46+
from autogen_agentchat.ui import Console
4247
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
4348
from autogen_ext.models.openai import OpenAIChatCompletionClient
4449
from autogen_agentchat.teams import RoundRobinGroupChat
45-
from autogen_agentchat.conditions import MaxMessageTermination
50+
from autogen_agentchat.conditions import TextMentionTermination
4651
4752
4853
async def main() -> None:
4954
model_client = OpenAIChatCompletionClient(model="gpt-4o")
5055
51-
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
52-
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
53-
inner_termination = MaxMessageTermination(3)
56+
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a writer, write well.")
57+
agent2 = AssistantAgent(
58+
"assistant2",
59+
model_client=model_client,
60+
system_message="You are an editor, provide critical feedback. Respond with 'APPROVE' if the text addresses all feedbacks.",
61+
)
62+
inner_termination = TextMentionTermination("APPROVE")
5463
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
5564
5665
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
5766
58-
agent3 = AssistantAgent("assistant3", model_client=model_client, system_message="You are a helpful assistant.")
59-
agent4 = AssistantAgent("assistant4", model_client=model_client, system_message="You are a helpful assistant.")
60-
outter_termination = MaxMessageTermination(10)
61-
team = RoundRobinGroupChat([society_of_mind_agent, agent3, agent4], termination_condition=outter_termination)
67+
agent3 = AssistantAgent(
68+
"assistant3", model_client=model_client, system_message="Translate the text to Spanish."
69+
)
70+
team = RoundRobinGroupChat([society_of_mind_agent, agent3], max_turns=2)
6271
63-
stream = team.run_stream(task="Tell me a one-liner joke.")
64-
async for message in stream:
65-
print(message)
72+
stream = team.run_stream(task="Write a short story with a surprising ending.")
73+
await Console(stream)
6674
6775
6876
asyncio.run(main())
6977
"""
7078

79+
DEFAULT_INSTRUCTION = "Earlier you were asked to fulfill a request. You and your team worked diligently to address that request. Here is a transcript of that conversation:"
80+
"""str: The default instruction to use when generating a response using the
81+
inner team's messages. The instruction will be prepended to the inner team's
82+
messages when generating a response using the model. It assumes the role of
83+
'system'."""
84+
85+
DEFAULT_RESPONSE_PROMPT = (
86+
"Output a standalone response to the original request, without mentioning any of the intermediate discussion."
87+
)
88+
"""str: The default response prompt to use when generating a response using
89+
the inner team's messages. It assumes the role of 'system'."""
90+
7191
def __init__(
7292
self,
7393
name: str,
7494
team: Team,
7595
model_client: ChatCompletionClient,
7696
*,
7797
description: str = "An agent that uses an inner team of agents to generate responses.",
78-
task_prompt: str = "{transcript}\nContinue.",
79-
response_prompt: str = "Here is a transcript of conversation so far:\n{transcript}\n\\Provide a response to the original request.",
98+
instruction: str = DEFAULT_INSTRUCTION,
99+
response_prompt: str = DEFAULT_RESPONSE_PROMPT,
80100
) -> None:
81101
super().__init__(name=name, description=description)
82102
self._team = team
83103
self._model_client = model_client
84-
if "{transcript}" not in task_prompt:
85-
raise ValueError("The task prompt must contain the '{transcript}' placeholder for the transcript.")
86-
self._task_prompt = task_prompt
87-
if "{transcript}" not in response_prompt:
88-
raise ValueError("The response prompt must contain the '{transcript}' placeholder for the transcript.")
104+
self._instruction = instruction
89105
self._response_prompt = response_prompt
90106

91107
@property
@@ -104,33 +120,41 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
104120
async def on_messages_stream(
105121
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
106122
) -> AsyncGenerator[AgentMessage | Response, None]:
107-
# Build the context.
108-
delta = list(messages)
109-
task: str | None = None
110-
if len(delta) > 0:
111-
task = self._task_prompt.format(transcript=self._create_transcript(delta))
123+
# Prepare the task for the team of agents.
124+
task = list(messages)
112125

113126
# Run the team of agents.
114127
result: TaskResult | None = None
115128
inner_messages: List[AgentMessage] = []
129+
count = 0
116130
async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
117131
if isinstance(inner_msg, TaskResult):
118132
result = inner_msg
119133
else:
134+
count += 1
135+
if count <= len(task):
136+
# Skip the task messages.
137+
continue
120138
yield inner_msg
121139
inner_messages.append(inner_msg)
122140
assert result is not None
123141

124-
if len(inner_messages) < 2:
125-
# The first message is the task message so we need at least 2 messages.
142+
if len(inner_messages) == 0:
126143
yield Response(
127144
chat_message=TextMessage(source=self.name, content="No response."), inner_messages=inner_messages
128145
)
129146
else:
130-
prompt = self._response_prompt.format(transcript=self._create_transcript(inner_messages[1:]))
131-
completion = await self._model_client.create(
132-
messages=[SystemMessage(content=prompt)], cancellation_token=cancellation_token
147+
# Generate a response using the model client.
148+
llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
149+
llm_messages.extend(
150+
[
151+
UserMessage(content=message.content, source=message.source)
152+
for message in inner_messages
153+
if isinstance(message, TextMessage | MultiModalMessage | StopMessage | HandoffMessage)
154+
]
133155
)
156+
llm_messages.append(SystemMessage(content=self._response_prompt))
157+
completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
134158
assert isinstance(completion.content, str)
135159
yield Response(
136160
chat_message=TextMessage(source=self.name, content=completion.content, models_usage=completion.usage),
@@ -143,17 +167,11 @@ async def on_messages_stream(
143167
async def on_reset(self, cancellation_token: CancellationToken) -> None:
144168
await self._team.reset()
145169

146-
def _create_transcript(self, messages: Sequence[AgentMessage]) -> str:
147-
transcript = ""
148-
for message in messages:
149-
if isinstance(message, TextMessage | StopMessage | HandoffMessage):
150-
transcript += f"{message.source}: {message.content}\n"
151-
elif isinstance(message, MultiModalMessage):
152-
for content in message.content:
153-
if isinstance(content, Image):
154-
transcript += f"{message.source}: [Image]\n"
155-
else:
156-
transcript += f"{message.source}: {content}\n"
157-
else:
158-
raise ValueError(f"Unexpected message type: {message} in {self.__class__.__name__}")
159-
return transcript
170+
async def save_state(self) -> Mapping[str, Any]:
171+
team_state = await self._team.save_state()
172+
state = SocietyOfMindAgentState(inner_team_state=team_state)
173+
return state.model_dump()
174+
175+
async def load_state(self, state: Mapping[str, Any]) -> None:
176+
society_of_mind_state = SocietyOfMindAgentState.model_validate(state)
177+
await self._team.load_state(society_of_mind_state.inner_team_state)

Diff for: python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import AsyncGenerator, Protocol, Sequence
2+
from typing import AsyncGenerator, List, Protocol, Sequence
33

44
from autogen_core import CancellationToken
55

@@ -23,7 +23,7 @@ class TaskRunner(Protocol):
2323
async def run(
2424
self,
2525
*,
26-
task: str | ChatMessage | None = None,
26+
task: str | ChatMessage | List[ChatMessage] | None = None,
2727
cancellation_token: CancellationToken | None = None,
2828
) -> TaskResult:
2929
"""Run the task and return the result.
@@ -36,7 +36,7 @@ async def run(
3636
def run_stream(
3737
self,
3838
*,
39-
task: str | ChatMessage | None = None,
39+
task: str | ChatMessage | List[ChatMessage] | None = None,
4040
cancellation_token: CancellationToken | None = None,
4141
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
4242
"""Run the task and produces a stream of messages and the final result

Diff for: python/packages/autogen-agentchat/src/autogen_agentchat/state/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
MagenticOneOrchestratorState,
99
RoundRobinManagerState,
1010
SelectorManagerState,
11+
SocietyOfMindAgentState,
1112
SwarmManagerState,
1213
TeamState,
1314
)
@@ -22,4 +23,5 @@
2223
"SwarmManagerState",
2324
"MagenticOneOrchestratorState",
2425
"TeamState",
26+
"SocietyOfMindAgentState",
2527
]

Diff for: python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py

+7
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,10 @@ class MagenticOneOrchestratorState(BaseGroupChatManagerState):
7979
n_rounds: int = Field(default=0)
8080
n_stalls: int = Field(default=0)
8181
type: str = Field(default="MagenticOneOrchestratorState")
82+
83+
84+
class SocietyOfMindAgentState(BaseState):
85+
"""State for a Society of Mind agent."""
86+
87+
inner_team_state: Mapping[str, Any] = Field(default_factory=dict)
88+
type: str = Field(default="SocietyOfMindAgentState")

0 commit comments

Comments (0)