forked from openai/openai-agents-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualizations.py
115 lines (86 loc) · 3.32 KB
/
visualizations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import graphviz
from src.agents.agent import Agent
def get_main_graph(agent: Agent) -> str:
    """
    Build the complete DOT source describing ``agent``'s workflow graph.

    The result opens a ``digraph G`` with shared graph/node/edge defaults and
    the ``__start__``/``__end__`` terminal nodes, followed by the node
    declarations from ``get_all_nodes`` and the edges from ``get_all_edges``,
    and finally the closing brace of the graph body.

    Args:
        agent (Agent): The root agent whose graph is generated.

    Returns:
        str: The DOT format string representing the graph.
    """
    header = (
        "\n"
        "digraph G {\n"
        "graph [splines=true];\n"
        'node [fontname="Arial"];\n'
        "edge [penwidth=1.5];\n"
        '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];\n'
        '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];\n'
    )
    # Nodes are generated before edges, matching the declaration order in the output.
    nodes = get_all_nodes(agent)
    edges = get_all_edges(agent)
    return f"{header}{nodes}{edges}}}"
def get_all_nodes(
    agent: "Agent",
    parent: "Agent | None" = None,
    visited: "set[str] | None" = None,
) -> str:
    """
    Recursively generate DOT node declarations for an agent, its tools, and its handoffs.

    Each agent is expanded at most once per traversal (keyed by name, which is
    also its DOT node id), so handoff cycles and shared sub-agents are safe.

    Args:
        agent (Agent): The agent for which the nodes are to be generated.
        parent (Agent, optional): The agent that hands off to ``agent``. ``None``
            marks ``agent`` as the root, which gets its own declaration here.
            Defaults to None.
        visited (set[str], optional): Names of agents already expanded during
            this traversal; used internally by the recursion. Defaults to None
            (start a fresh traversal).

    Returns:
        str: The DOT format string representing the nodes.
    """
    if visited is None:
        visited = set()
    # Stop on agents already expanded; without this, a handoff cycle
    # (e.g. a specialist handing back to its triage agent) recurses forever.
    if agent.name in visited:
        return ""
    visited.add(agent.name)

    parts = []

    # Only the root is declared here; every other agent is declared by the
    # agent that hands off to it (see the handoff loop below).
    if parent is None:
        parts.append(
            f'\n"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
            "fillcolor=lightyellow, width=1.5, height=0.8];"
        )

    # Smaller tools (ellipse, green)
    for tool in agent.tools:
        parts.append(
            f'\n"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
            "fillcolor=lightgreen, width=0.5, height=0.3];"
        )

    # Bigger handoffs (rounded box, yellow)
    # NOTE(review): assumes every handoff entry exposes .name/.tools/.handoffs
    # like an Agent — confirm whether wrapper handoff objects can appear here.
    for handoff in agent.handoffs:
        if handoff.name in visited:
            # Already declared; declaring it again would restyle the node,
            # because DOT applies the most recent attribute values.
            continue
        # Repeating `style=` keeps only the last value, which would drop the fill,
        # so both styles go into a single comma-separated value.
        parts.append(
            f'\n"{handoff.name}" [label="{handoff.name}", shape=box, '
            'style="filled,rounded", fillcolor=lightyellow, width=1.5, height=0.8];'
        )
        parts.append(get_all_nodes(handoff, agent, visited))

    return "".join(parts)
def get_all_edges(
    agent: "Agent",
    parent: "Agent | None" = None,
    visited: "set[str] | None" = None,
) -> str:
    """
    Recursively generate DOT edges for an agent, its tools, and its handoffs.

    Each agent is expanded at most once per traversal (keyed by name, which is
    also its DOT node id), so handoff cycles terminate and shared sub-agents do
    not produce duplicate edges.

    Args:
        agent (Agent): The agent for which the edges are to be generated.
        parent (Agent, optional): The parent agent. ``None`` marks ``agent`` as
            the root, which receives the edge from ``__start__``. Defaults to None.
        visited (set[str], optional): Names of agents already expanded during
            this traversal; used internally by the recursion. Defaults to None
            (start a fresh traversal).

    Returns:
        str: The DOT format string representing the edges.
    """
    if visited is None:
        visited = set()
    # Without this guard a handoff cycle would recurse until RecursionError.
    if agent.name in visited:
        return ""
    visited.add(agent.name)

    parts = []

    if parent is None:
        parts.append(f'\n"__start__" -> "{agent.name}";')

    # Tool calls return to the calling agent, so draw both directions.
    for tool in agent.tools:
        parts.append(
            f'\n"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];'
            f'\n"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];'
        )

    # An agent with no handoffs is a terminal step of the workflow.
    if not agent.handoffs:
        parts.append(f'\n"{agent.name}" -> "__end__";')

    # NOTE(review): assumes every handoff entry exposes .name/.tools/.handoffs
    # like an Agent — confirm whether wrapper handoff objects can appear here.
    for handoff in agent.handoffs:
        # The edge is always drawn (so hand-backs to visited agents show up);
        # only the recursion below is skipped for already-expanded agents.
        parts.append(f'\n"{agent.name}" -> "{handoff.name}";')
        parts.append(get_all_edges(handoff, agent, visited))

    return "".join(parts)
def draw_graph(agent: Agent, filename: "str | None" = None) -> graphviz.Source:
    """
    Build the graph for the given agent and optionally render it as a PNG file.

    Args:
        agent (Agent): The agent for which the graph is to be drawn.
        filename (str, optional): Base path handed to ``graphviz.Source.render``;
            graphviz writes the image as ``<filename>.png``. When omitted or
            empty, nothing is written to disk. Defaults to None.

    Returns:
        graphviz.Source: The graphviz Source object representing the graph.
    """
    graph = graphviz.Source(get_main_graph(agent))
    if filename:
        # render() first saves the DOT source to `filename`; cleanup=True removes
        # that intermediate file so only the PNG is left behind.
        graph.render(filename, format="png", cleanup=True)
    return graph