-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisual_agent.py
75 lines (64 loc) · 2.86 KB
/
visual_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import AsyncGenerator, Sequence
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
from autogen_core import CancellationToken
import json
from autogen_core import Image
from autogen_agentchat.messages import ToolCallExecutionEvent
from figure_processing import get_figures_from_chunk
import logging
class VisualAgent(AssistantAgent):
    """Assistant agent that re-injects tool results as multimodal messages.

    Tool call execution events whose first result is a JSON object are
    replaced by a ``MultiModalMessage`` containing, for every entry, the
    cleaned chunk text followed by the images returned by
    ``get_figures_from_chunk``. All other messages are forwarded unchanged.
    """

    def __init__(
        self,
        name: str,
        description: str,
        system_message: str,
        model_client,
        model_client_stream,
        chunk_and_figure_pairs: dict,
    ) -> None:
        """Create the agent.

        Args:
            name: Forwarded to ``AssistantAgent``.
            description: Forwarded to ``AssistantAgent``.
            system_message: Forwarded to ``AssistantAgent``.
            model_client: Forwarded to ``AssistantAgent``.
            model_client_stream: Forwarded to ``AssistantAgent``.
            chunk_and_figure_pairs: Lookup handed as-is to
                ``get_figures_from_chunk`` for every retrieved chunk.
                NOTE(review): its exact schema is defined by
                ``figure_processing`` -- confirm there.
        """
        super().__init__(
            name=name,
            description=description,
            system_message=system_message,
            model_client=model_client,
            model_client_stream=model_client_stream,
        )
        self.chunk_and_figure_pairs = chunk_and_figure_pairs

    def _build_multi_modal_content(self, results: dict) -> list:
        """Return the text and image parts for a parsed search-result mapping.

        Args:
            results: Mapping of chunk id to a result mapping that has a
                ``"Chunk"`` entry.

        Returns:
            A flat list where each chunk contributes its cleaned text
            followed by one ``Image`` per base64 payload returned for it.
        """
        multi_modal_content = []
        for chunk_id, result in results.items():
            cleaned_text, chunk_image_retrievals = get_figures_from_chunk(
                self.chunk_and_figure_pairs,
                result["Chunk"],
                chunk_id=chunk_id,
            )
            multi_modal_content.append(cleaned_text)
            multi_modal_content.extend(
                Image.from_base64(image) for image in chunk_image_retrievals
            )
        return multi_modal_content

    async def on_messages_stream(
        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[AgentEvent | Response, None]:
        """Rewrite tool results into multimodal messages, then delegate.

        Args:
            messages: Incoming messages; any ``ToolCallExecutionEvent`` among
                them is inspected and possibly replaced.
            cancellation_token: Passed through to the parent implementation.

        Yields:
            Every event produced by ``AssistantAgent.on_messages_stream`` for
            the rewritten message list.
        """
        # Override and insert the multimodal messages into the context here.
        # This is a workaround as AutoGen doesn't support multi-modal tool
        # call responses yet.
        multi_modal_messages = []
        for message in messages:
            if not isinstance(message, ToolCallExecutionEvent):
                multi_modal_messages.append(message)
                continue

            try:
                # Only the first execution result of the event is inspected.
                ai_search_results = message.content[0].content
                results = json.loads(ai_search_results)
            except (IndexError, TypeError, json.JSONDecodeError):
                # No result, a non-string payload, or invalid JSON: keep the
                # original event rather than failing the whole stream.
                multi_modal_messages.append(message)
                continue

            if not isinstance(results, dict):
                # Valid JSON but not a chunk-id mapping; nothing to expand.
                multi_modal_messages.append(message)
                continue

            multi_modal_content = self._build_multi_modal_content(results)
            # An empty result mapping yields no replacement message at all.
            if multi_modal_content:
                logging.info("Sending multimodal message")
                logging.info("Sending %i messages", len(multi_modal_content))
                multi_modal_messages.append(
                    MultiModalMessage(
                        content=multi_modal_content,
                        # ``source`` is a required message field; attribute
                        # the content to whoever produced the tool result.
                        source=message.source,
                    )
                )

        async for event in super().on_messages_stream(
            multi_modal_messages, cancellation_token
        ):
            yield event

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Do nothing on reset.

        NOTE(review): this override does not call the parent ``on_reset``,
        so whatever the base class resets is skipped -- confirm this is
        intended.
        """
        pass