diff --git a/.gitignore b/.gitignore index 2ab5a819..7dd22b88 100644 --- a/.gitignore +++ b/.gitignore @@ -141,4 +141,4 @@ cython_debug/ .ruff_cache/ # PyPI configuration file -.pypirc +.pypirc \ No newline at end of file diff --git a/docs/assets/images/graph.png b/docs/assets/images/graph.png new file mode 100644 index 00000000..13e2d6eb Binary files /dev/null and b/docs/assets/images/graph.png differ diff --git a/docs/visualization.md b/docs/visualization.md new file mode 100644 index 00000000..00f3126d --- /dev/null +++ b/docs/visualization.md @@ -0,0 +1,86 @@ +# Agent Visualization + +Agent visualization allows you to generate a structured graphical representation of agents and their relationships using **Graphviz**. This is useful for understanding how agents, tools, and handoffs interact within an application. + +## Installation + +Install the optional `viz` dependency group: + +```bash +pip install "openai-agents[viz]" +``` + +## Generating a Graph + +You can generate an agent visualization using the `draw_graph` function. This function creates a directed graph where: + +- **Agents** are represented as yellow boxes. +- **Tools** are represented as green ellipses. +- **Handoffs** are directed edges from one agent to another. + +### Example Usage + +```python +from agents import Agent, function_tool +from agents.extensions.visualization import draw_graph + +@function_tool +def get_weather(city: str) -> str: + return f"The weather in {city} is sunny." + +spanish_agent = Agent( + name="Spanish agent", + instructions="You only speak Spanish.", +) + +english_agent = Agent( + name="English agent", + instructions="You only speak English", +) + +triage_agent = Agent( + name="Triage agent", + instructions="Handoff to the appropriate agent based on the language of the request.", + handoffs=[spanish_agent, english_agent], + tools=[get_weather], +) + +draw_graph(triage_agent) +``` + +![Agent Graph](./assets/images/graph.png) + +This generates a graph that visually represents the structure of the **triage agent** and its connections to sub-agents and tools. + + +## Understanding the Visualization + +The generated graph includes: + +- A **start node** (`__start__`) indicating the entry point. +- Agents represented as **rectangles** with yellow fill. +- Tools represented as **ellipses** with green fill. +- Directed edges indicating interactions: + - **Solid arrows** for agent-to-agent handoffs. + - **Dotted arrows** for tool invocations. +- An **end node** (`__end__`) indicating where execution terminates. + +## Customizing the Graph + +### Showing the Graph +By default, `draw_graph` displays the graph inline. To show the graph in a separate window, write the following: + +```python +draw_graph(triage_agent).view() +``` + +### Saving the Graph +By default, `draw_graph` displays the graph inline. To save it as a file, specify a filename: + +```python +draw_graph(triage_agent, filename="agent_graph.png") +``` + +This will generate `agent_graph.png` in the working directory. + + diff --git a/mkdocs.yml b/mkdocs.yml index 941f29ed..a27a6369 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -35,6 +35,7 @@ nav: - multi_agent.md - models.md - config.md + - visualization.md - Voice agents: - voice/quickstart.md - voice/pipeline.md diff --git a/pyproject.toml b/pyproject.toml index 3678c714..eb0bae39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ Repository = "https://github.com/openai/openai-agents-python" [project.optional-dependencies] voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"] +viz = ["graphviz>=0.17"] [dependency-groups] dev = [ @@ -56,7 +57,9 @@ dev = [ "pynput", "textual", "websockets", + "graphviz", ] + [tool.uv.workspace] members = ["agents"] diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py new file mode 100644 index 00000000..5fb35062 --- /dev/null +++ b/src/agents/extensions/visualization.py @@ -0,0 +1,137 @@ +from typing import Optional + +import graphviz # type: ignore + +from agents import Agent +from agents.handoffs import Handoff +from agents.tool import Tool + + +def get_main_graph(agent: Agent) -> str: + """ + Generates the main graph structure in DOT format for the given agent. + + Args: + agent (Agent): The agent for which the graph is to be generated. + + Returns: + str: The DOT format string representing the graph. + """ + parts = [ + """ + digraph G { + graph [splines=true]; + node [fontname="Arial"]; + edge [penwidth=1.5]; + """ + ] + parts.append(get_all_nodes(agent)) + parts.append(get_all_edges(agent)) + parts.append("}") + return "".join(parts) + + +def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: + """ + Recursively generates the nodes for the given agent and its handoffs in DOT format. + + Args: + agent (Agent): The agent for which the nodes are to be generated. + + Returns: + str: The DOT format string representing the nodes. + """ + parts = [] + + # Start and end the graph + parts.append( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + ) + # Ensure parent agent node is colored + if not parent: + parts.append( + f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" + ) + + for tool in agent.tools: + parts.append( + f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, ' + f"fillcolor=lightgreen, width=0.5, height=0.3];" + ) + + for handoff in agent.handoffs: + if isinstance(handoff, Handoff): + parts.append( + f'"{handoff.agent_name}" [label="{handoff.agent_name}", ' + f"shape=box, style=filled, style=rounded, " + f"fillcolor=lightyellow, width=1.5, height=0.8];" + ) + if isinstance(handoff, Agent): + parts.append( + f'"{handoff.name}" [label="{handoff.name}", ' + f"shape=box, style=filled, style=rounded, " + f"fillcolor=lightyellow, width=1.5, height=0.8];" + ) + parts.append(get_all_nodes(handoff)) + + return "".join(parts) + + +def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: + """ + Recursively generates the edges for the given agent and its handoffs in DOT format. + + Args: + agent (Agent): The agent for which the edges are to be generated. + parent (Agent, optional): The parent agent. Defaults to None. + + Returns: + str: The DOT format string representing the edges. + """ + parts = [] + + if not parent: + parts.append(f'"__start__" -> "{agent.name}";') + + for tool in agent.tools: + parts.append(f""" + "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5]; + "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""") + + for handoff in agent.handoffs: + if isinstance(handoff, Handoff): + parts.append(f""" + "{agent.name}" -> "{handoff.agent_name}";""") + if isinstance(handoff, Agent): + parts.append(f""" + "{agent.name}" -> "{handoff.name}";""") + parts.append(get_all_edges(handoff, agent)) + + if not agent.handoffs and not isinstance(agent, Tool): # type: ignore + parts.append(f'"{agent.name}" -> "__end__";') + + return "".join(parts) + + +def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: + """ + Draws the graph for the given agent and optionally saves it as a PNG file. + + Args: + agent (Agent): The agent for which the graph is to be drawn. + filename (str): The name of the file to save the graph as a PNG. + + Returns: + graphviz.Source: The graphviz Source object representing the graph. + """ + dot_code = get_main_graph(agent) + graph = graphviz.Source(dot_code) + + if filename: + graph.render(filename, format="png") + + return graph diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 00000000..6aa86774 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,136 @@ +from unittest.mock import Mock + +import graphviz # type: ignore +import pytest + +from agents import Agent +from agents.extensions.visualization import ( + draw_graph, + get_all_edges, + get_all_nodes, + get_main_graph, +) +from agents.handoffs import Handoff + + +@pytest.fixture +def mock_agent(): + tool1 = Mock() + tool1.name = "Tool1" + tool2 = Mock() + tool2.name = "Tool2" + + handoff1 = Mock(spec=Handoff) + handoff1.agent_name = "Handoff1" + + agent = Mock(spec=Agent) + agent.name = "Agent1" + agent.tools = [tool1, tool2] + agent.handoffs = [handoff1] + + return agent + + +def test_get_main_graph(mock_agent): + result = get_main_graph(mock_agent) + print(result) + assert "digraph G" in result + assert "graph [splines=true];" in result + assert 'node [fontname="Arial"];' in result + assert "edge [penwidth=1.5];" in result + assert ( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in result + ) + assert ( + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in result + ) + assert ( + '"Agent1" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in result + ) + assert ( + '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in result + ) + assert ( + '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in result + ) + assert ( + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in result + ) + + +def test_get_all_nodes(mock_agent): + result = get_all_nodes(mock_agent) + assert ( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in result + ) + assert ( + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in result + ) + assert ( + '"Agent1" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in result + ) + assert ( + '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in result + ) + assert ( + '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in result + ) + assert ( + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in result + ) + + +def test_get_all_edges(mock_agent): + result = get_all_edges(mock_agent) + assert '"__start__" -> "Agent1";' in result + assert '"Agent1" -> "__end__";' + assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result + assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result + assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result + assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result + assert '"Agent1" -> "Handoff1";' in result + + +def test_draw_graph(mock_agent): + graph = draw_graph(mock_agent) + assert isinstance(graph, graphviz.Source) + assert "digraph G" in graph.source + assert "graph [splines=true];" in graph.source + assert 'node [fontname="Arial"];' in graph.source + assert "edge [penwidth=1.5];" in graph.source + assert ( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in graph.source + ) + assert ( + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in graph.source + ) + assert ( + '"Agent1" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source + ) + assert ( + '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source + ) + assert ( + '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source + ) + assert ( + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source + ) diff --git a/uv.lock b/uv.lock index d6eba43f..3ee7f047 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.10'", @@ -348,6 +349,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034 }, ] +[[package]] +name = "graphviz" +version = "0.20.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/83/5a40d19b8347f017e417710907f824915fba411a9befd092e52746b63e9f/graphviz-0.20.3.zip", hash = "sha256:09d6bc81e6a9fa392e7ba52135a9d49f1ed62526f96499325930e87ca1b5925d", size = 256455 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/be/d59db2d1d52697c6adc9eacaf50e8965b6345cc143f671e1ed068818d5cf/graphviz-0.20.3-py3-none-any.whl", hash = "sha256:81f848f2904515d8cd359cc611faba817598d2feaac4027b266aa3eda7b3dde5", size = 47126 }, +] + [[package]] name = "greenlet" version = "3.1.1" @@ -1090,6 +1100,9 @@ dependencies = [ ] [package.optional-dependencies] +viz = [ + { name = "graphviz" }, +] voice = [ { name = "numpy", version = "2.2.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, { name = "websockets" }, @@ -1098,6 +1111,7 @@ voice = [ [package.dev-dependencies] dev = [ { name = "coverage" }, + { name = "graphviz" }, { name = "inline-snapshot" }, { name = "mkdocs" }, { name = "mkdocs-material" }, @@ -1118,6 +1132,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" }, { name = "griffe", specifier = ">=1.5.6,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, @@ -1128,10 +1143,12 @@ requires-dist = [ { name = "typing-extensions", specifier = ">=4.12.2,<5" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" }, ] +provides-extras = ["voice", "viz"] [package.metadata.requires-dev] dev = [ { name = "coverage", specifier = ">=7.6.12" }, + { name = "graphviz" }, { name = "inline-snapshot", specifier = ">=0.20.7" }, { name = "mkdocs", specifier = ">=1.6.0" }, { name = "mkdocs-material", specifier = ">=9.6.0" },