From c374ad064faa6824d5c5a6c5bd9870688526bc81 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Tue, 11 Mar 2025 22:53:48 +0200 Subject: [PATCH 1/3] Run make format --- src/agents/agent_output.py | 2 +- src/agents/model_settings.py | 1 + tests/src/agents/agent_output.py | 2 +- tests/src/agents/model_settings.py | 1 + tests/test_config.py | 9 ++++++--- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 8140d8c6..0c28800f 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -138,7 +138,7 @@ def _type_to_str(t: type[Any]) -> str: # It's a simple type like `str`, `int`, etc. return t.__name__ elif args: - args_str = ', '.join(_type_to_str(arg) for arg in args) + args_str = ", ".join(_type_to_str(arg) for arg in args) return f"{origin.__name__}[{args_str}]" else: return str(t) diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 78cf9a83..d8178ae3 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -11,6 +11,7 @@ class ModelSettings: This class holds optional model configuration parameters (e.g. temperature, top_p, penalties, truncation, etc.). """ + temperature: float | None = None top_p: float | None = None frequency_penalty: float | None = None diff --git a/tests/src/agents/agent_output.py b/tests/src/agents/agent_output.py index 8140d8c6..0c28800f 100644 --- a/tests/src/agents/agent_output.py +++ b/tests/src/agents/agent_output.py @@ -138,7 +138,7 @@ def _type_to_str(t: type[Any]) -> str: # It's a simple type like `str`, `int`, etc. return t.__name__ elif args: - args_str = ', '.join(_type_to_str(arg) for arg in args) + args_str = ", ".join(_type_to_str(arg) for arg in args) return f"{origin.__name__}[{args_str}]" else: return str(t) diff --git a/tests/src/agents/model_settings.py b/tests/src/agents/model_settings.py index 78cf9a83..d8178ae3 100644 --- a/tests/src/agents/model_settings.py +++ b/tests/src/agents/model_settings.py @@ -11,6 +11,7 @@ class ModelSettings: This class holds optional model configuration parameters (e.g. temperature, top_p, penalties, truncation, etc.). """ + temperature: float | None = None top_p: float | None = None frequency_penalty: float | None = None diff --git a/tests/test_config.py b/tests/test_config.py index 8f37200a..dba854db 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -49,13 +49,16 @@ def test_resp_set_default_openai_client(): def test_set_default_openai_api(): - assert isinstance(OpenAIProvider().get_model("gpt-4"), OpenAIResponsesModel), \ + assert isinstance(OpenAIProvider().get_model("gpt-4"), OpenAIResponsesModel), ( "Default should be responses" + ) set_default_openai_api("chat_completions") - assert isinstance(OpenAIProvider().get_model("gpt-4"), OpenAIChatCompletionsModel), \ + assert isinstance(OpenAIProvider().get_model("gpt-4"), OpenAIChatCompletionsModel), ( "Should be chat completions model" + ) set_default_openai_api("responses") - assert isinstance(OpenAIProvider().get_model("gpt-4"), OpenAIResponsesModel), \ + assert isinstance(OpenAIProvider().get_model("gpt-4"), OpenAIResponsesModel), ( "Should be responses model" + ) From c03d314fb80181858693e915be3d26848e437fa5 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Tue, 11 Mar 2025 22:57:14 +0200 Subject: [PATCH 2/3] Stronger tracing tests with inline-snapshot --- pyproject.toml | 3 +- tests/test_agent_tracing.py | 115 +++++++- tests/test_responses_tracing.py | 33 ++- tests/test_tracing_errors.py | 279 ++++++++++++++++++- tests/test_tracing_errors_streamed.py | 382 +++++++++++++++++++++++++- tests/testing_processor.py | 33 +++ uv.lock | 35 +++ 7 files changed, 875 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9c18d5f6..17265e73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dev = [ "mkdocstrings[python]>=0.28.0", "coverage>=7.6.12", "playwright==1.50.0", + "inline-snapshot>=0.20.5", ] [tool.uv.workspace] members = ["agents"] @@ -116,4 +117,4 @@ filterwarnings = [ ] markers = [ "allow_call_model_methods: mark test as allowing calls to real model implementations", -] \ No newline at end of file +] diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py index 24bd72f1..3d7196ab 100644 --- a/tests/test_agent_tracing.py +++ b/tests/test_agent_tracing.py @@ -3,12 +3,13 @@ import asyncio import pytest +from inline_snapshot import snapshot from agents import Agent, RunConfig, Runner, trace from .fake_model import FakeModel from .test_responses import get_text_message -from .testing_processor import fetch_ordered_spans, fetch_traces +from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces @pytest.mark.asyncio @@ -25,6 +26,25 @@ async def test_single_run_is_single_trace(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 1, ( f"Got {len(spans)}, but expected 1: the agent span. data:" @@ -52,6 +72,39 @@ async def test_multiple_runs_are_multiple_traces(): traces = fetch_traces() assert len(traces) == 2, f"Expected 2 traces, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + } + ], + }, + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + } + ], + }, + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 2, f"Got {len(spans)}, but expected 2: agent span per run" @@ -79,6 +132,43 @@ async def test_wrapped_trace_is_single_trace(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "test_workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + }, + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 3, f"Got {len(spans)}, but expected 3: the agent span per run" @@ -97,6 +187,8 @@ async def test_parent_disabled_trace_disabled_agent_trace(): traces = fetch_traces() assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}" + assert fetch_normalized_spans() == snapshot([]) + spans = fetch_ordered_spans() assert len(spans) == 0, ( f"Expected no spans, got {len(spans)}, with {[x.span_data for x in spans]}" @@ -116,6 +208,8 @@ async def test_manual_disabling_works(): traces = fetch_traces() assert len(traces) == 0, f"Expected 0 traces, got {len(traces)}" + assert fetch_normalized_spans() == snapshot([]) + spans = fetch_ordered_spans() assert len(spans) == 0, f"Got {len(spans)}, but expected no spans" @@ -164,6 +258,25 @@ async def test_not_starting_streaming_creates_trace(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 1, f"Got {len(spans)}, but expected 1: the agent span" diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py index 82b8e75b..41b87eb3 100644 --- a/tests/test_responses_tracing.py +++ b/tests/test_responses_tracing.py @@ -1,4 +1,5 @@ import pytest +from inline_snapshot import snapshot from openai import AsyncOpenAI from openai.types.responses import ResponseCompletedEvent @@ -6,7 +7,7 @@ from agents.tracing.span_data import ResponseSpanData from tests import fake_model -from .testing_processor import fetch_ordered_spans +from .testing_processor import fetch_normalized_spans, fetch_ordered_spans class DummyTracing: @@ -54,6 +55,15 @@ async def dummy_fetch_response( "instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED ) + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "test", + "children": [{"type": "response", "data": {"response_id": "dummy-id"}}], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 1 @@ -82,6 +92,10 @@ async def dummy_fetch_response( "instr", "input", ModelSettings(), [], None, [], ModelTracing.ENABLED_WITHOUT_DATA ) + assert fetch_normalized_spans() == snapshot( + [{"workflow_name": "test", "children": [{"type": "response"}]}] + ) + spans = fetch_ordered_spans() assert len(spans) == 1 assert spans[0].span_data.response is None @@ -107,6 +121,8 @@ async def dummy_fetch_response( "instr", "input", ModelSettings(), [], None, [], ModelTracing.DISABLED ) + assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}]) + spans = fetch_ordered_spans() assert len(spans) == 0 @@ -139,6 +155,15 @@ async def __aiter__(self): ): pass + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "test", + "children": [{"type": "response", "data": {"response_id": "dummy-id-123"}}], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 1 assert isinstance(spans[0].span_data, ResponseSpanData) @@ -174,6 +199,10 @@ async def __aiter__(self): ): pass + assert fetch_normalized_spans() == snapshot( + [{"workflow_name": "test", "children": [{"type": "response"}]}] + ) + spans = fetch_ordered_spans() assert len(spans) == 1 assert isinstance(spans[0].span_data, ResponseSpanData) @@ -208,5 +237,7 @@ async def __aiter__(self): ): pass + assert fetch_normalized_spans() == snapshot([{"workflow_name": "test"}]) + spans = fetch_ordered_spans() assert len(spans) == 0 diff --git a/tests/test_tracing_errors.py b/tests/test_tracing_errors.py index d57e1a84..5dbd7c1b 100644 --- a/tests/test_tracing_errors.py +++ b/tests/test_tracing_errors.py @@ -4,6 +4,7 @@ from typing import Any import pytest +from inline_snapshot import snapshot from typing_extensions import TypedDict from agents import ( @@ -27,7 +28,7 @@ get_handoff_tool_call, get_text_message, ) -from .testing_processor import fetch_ordered_spans, fetch_traces +from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces @pytest.mark.asyncio @@ -45,6 +46,34 @@ async def test_single_turn_model_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + "children": [ + { + "type": "generation", + "error": { + "message": "Error", + "data": {"name": "ValueError", "message": "test error"}, + }, + } + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}" @@ -80,6 +109,43 @@ async def test_multi_turn_no_handoffs(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent", + "handoffs": [], + "tools": ["foo"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "foo", + "input": '{"a": "b"}', + "output": "tool_result", + }, + }, + { + "type": "generation", + "error": { + "message": "Error", + "data": {"name": "ValueError", "message": "test error"}, + }, + }, + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 4, ( f"should have agent, generation, tool, generation, got {len(spans)} with data: " @@ -110,6 +176,39 @@ async def test_tool_call_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent", + "handoffs": [], + "tools": ["foo"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "error": { + "message": "Error running tool", + "data": { + "tool_name": "foo", + "error": "Invalid JSON input for tool foo: bad_json", + }, + }, + "data": {"name": "foo", "input": "bad_json"}, + }, + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 3, ( f"should have agent, generation, tool spans, got {len(spans)} with data: " @@ -159,6 +258,43 @@ async def test_multiple_handoff_doesnt_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test", + "handoffs": ["test", "test"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "some_function", + "input": '{"a": "b"}', + "output": "result", + }, + }, + {"type": "generation"}, + {"type": "handoff", "data": {"from_agent": "test", "to_agent": "test"}}, + ], + }, + { + "type": "agent", + "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"}, + "children": [{"type": "generation"}], + }, + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 7, ( f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: " @@ -193,6 +329,21 @@ async def test_multiple_final_output_doesnt_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "Foo"}, + "children": [{"type": "generation"}], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 2, ( f"should have 1 agent, 1 generation, got {len(spans)} with data: " @@ -251,6 +402,76 @@ async def test_handoffs_lead_to_correct_agent_spans(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_3", + "handoffs": ["test_agent_1", "test_agent_2"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "some_function", + "input": '{"a": "b"}', + "output": "result", + }, + }, + {"type": "generation"}, + { + "type": "handoff", + "data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"}, + }, + ], + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": ["test_agent_3"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "some_function", + "input": '{"a": "b"}', + "output": "result", + }, + }, + {"type": "generation"}, + { + "type": "handoff", + "data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"}, + }, + ], + }, + { + "type": "agent", + "data": { + "name": "test_agent_3", + "handoffs": ["test_agent_1", "test_agent_2"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [{"type": "generation"}], + }, + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 12, ( f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: " @@ -285,6 +506,38 @@ async def test_max_turns_exceeded(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "error": {"message": "Max turns exceeded", "data": {"max_turns": 2}}, + "data": { + "name": "test", + "handoffs": [], + "tools": ["foo"], + "output_type": "Foo", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": {"name": "foo", "input": "", "output": "result"}, + }, + {"type": "generation"}, + { + "type": "function", + "data": {"name": "foo", "input": "", "output": "result"}, + }, + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 5, ( f"should have 1 agent span, 2 generations, 2 function calls, got " @@ -318,6 +571,30 @@ async def test_guardrail_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "error": { + "message": "Guardrail tripwire triggered", + "data": {"guardrail": "guardrail_function"}, + }, + "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"}, + "children": [ + { + "type": "guardrail", + "data": {"name": "guardrail_function", "triggered": True}, + } + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 2, ( f"should have 1 agent, 1 guardrail, got {len(spans)} with data: " diff --git a/tests/test_tracing_errors_streamed.py b/tests/test_tracing_errors_streamed.py index 00f440ee..74cda2de 100644 --- a/tests/test_tracing_errors_streamed.py +++ b/tests/test_tracing_errors_streamed.py @@ -5,6 +5,7 @@ from typing import Any import pytest +from inline_snapshot import snapshot from typing_extensions import TypedDict from agents import ( @@ -32,7 +33,7 @@ get_handoff_tool_call, get_text_message, ) -from .testing_processor import fetch_ordered_spans, fetch_traces +from .testing_processor import fetch_normalized_spans, fetch_ordered_spans, fetch_traces @pytest.mark.asyncio @@ -52,6 +53,35 @@ async def test_single_turn_model_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "error": {"message": "Error in agent run", "data": {"error": "test error"}}, + "data": { + "name": "test_agent", + "handoffs": [], + "tools": [], + "output_type": "str", + }, + "children": [ + { + "type": "generation", + "error": { + "message": "Error", + "data": {"name": "ValueError", "message": "test error"}, + }, + } + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 2, f"should have agent and generation spans, got {len(spans)}" @@ -89,6 +119,44 @@ async def test_multi_turn_no_handoffs(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "error": {"message": "Error in agent run", "data": {"error": "test error"}}, + "data": { + "name": "test_agent", + "handoffs": [], + "tools": ["foo"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "foo", + "input": '{"a": "b"}', + "output": "tool_result", + }, + }, + { + "type": "generation", + "error": { + "message": "Error", + "data": {"name": "ValueError", "message": "test error"}, + }, + }, + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 4, ( f"should have agent, generation, tool, generation, got {len(spans)} with data: " @@ -121,6 +189,43 @@ async def test_tool_call_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "error": { + "message": "Error in agent run", + "data": {"error": "Invalid JSON input for tool foo: bad_json"}, + }, + "data": { + "name": "test_agent", + "handoffs": [], + "tools": ["foo"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "error": { + "message": "Error running tool", + "data": { + "tool_name": "foo", + "error": "Invalid JSON input for tool foo: bad_json", + }, + }, + "data": {"name": "foo", "input": "bad_json"}, + }, + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 3, ( f"should have agent, generation, tool spans, got {len(spans)} with data: " @@ -173,6 +278,43 @@ async def test_multiple_handoff_doesnt_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test", + "handoffs": ["test", "test"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "some_function", + "input": '{"a": "b"}', + "output": "result", + }, + }, + {"type": "generation"}, + {"type": "handoff", "data": {"from_agent": "test", "to_agent": "test"}}, + ], + }, + { + "type": "agent", + "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"}, + "children": [{"type": "generation"}], + }, + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 7, ( f"should have 2 agent, 1 function, 3 generation, 1 handoff, got {len(spans)} with data: " @@ -211,6 +353,21 @@ async def test_multiple_final_output_no_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "Foo"}, + "children": [{"type": "generation"}], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 2, ( f"should have 1 agent, 1 generation, got {len(spans)} with data: " @@ -271,12 +428,152 @@ async def test_handoffs_lead_to_correct_agent_spans(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_3", + "handoffs": ["test_agent_1", "test_agent_2"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "some_function", + "input": '{"a": "b"}', + "output": "result", + }, + }, + {"type": "generation"}, + { + "type": "handoff", + "data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"}, + }, + ], + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": ["test_agent_3"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "some_function", + "input": '{"a": "b"}', + "output": "result", + }, + }, + {"type": "generation"}, + { + "type": "handoff", + "data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"}, + }, + ], + }, + { + "type": "agent", + "data": { + "name": "test_agent_3", + "handoffs": ["test_agent_1", "test_agent_2"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [{"type": "generation"}], + }, + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 12, ( f"should have 3 agents, 2 function, 5 generation, 2 handoff, got {len(spans)} with data: " f"{[x.span_data for x in spans]}" ) + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "data": { + "name": "test_agent_3", + "handoffs": ["test_agent_1", "test_agent_2"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "some_function", + "input": '{"a": "b"}', + "output": "result", + }, + }, + {"type": "generation"}, + { + "type": "handoff", + "data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"}, + }, + ], + }, + { + "type": "agent", + "data": { + "name": "test_agent_1", + "handoffs": ["test_agent_3"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": { + "name": "some_function", + "input": '{"a": "b"}', + "output": "result", + }, + }, + {"type": "generation"}, + { + "type": "handoff", + "data": {"from_agent": "test_agent_1", "to_agent": "test_agent_3"}, + }, + ], + }, + { + "type": "agent", + "data": { + "name": "test_agent_3", + "handoffs": ["test_agent_1", "test_agent_2"], + "tools": ["some_function"], + "output_type": "str", + }, + "children": [{"type": "generation"}], + }, + ], + } + ] + ) + @pytest.mark.asyncio async def test_max_turns_exceeded(): @@ -307,6 +604,38 @@ async def test_max_turns_exceeded(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "error": {"message": "Max turns exceeded", "data": {"max_turns": 2}}, + "data": { + "name": "test", + "handoffs": [], + "tools": ["foo"], + "output_type": "Foo", + }, + "children": [ + {"type": "generation"}, + { + "type": "function", + "data": {"name": "foo", "input": "", "output": "result"}, + }, + {"type": "generation"}, + { + "type": "function", + "data": {"name": "foo", "input": "", "output": "result"}, + }, + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 5, ( f"should have 1 agent, 2 generations, 2 function calls, got " @@ -347,6 +676,33 @@ async def test_input_guardrail_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "error": { + "message": "Guardrail tripwire triggered", + "data": { + "guardrail": "input_guardrail_function", + "type": "input_guardrail", + }, + }, + "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"}, + "children": [ + { + "type": "guardrail", + "data": {"name": "input_guardrail_function", "triggered": True}, + } + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 2, ( f"should have 1 agent, 1 guardrail, got {len(spans)} with data: " @@ -387,6 +743,30 @@ async def test_output_guardrail_error(): traces = fetch_traces() assert len(traces) == 1, f"Expected 1 trace, got {len(traces)}" + assert fetch_normalized_spans() == snapshot( + [ + { + "workflow_name": "Agent workflow", + "children": [ + { + "type": "agent", + "error": { + "message": "Guardrail tripwire triggered", + "data": {"guardrail": "output_guardrail_function"}, + }, + "data": {"name": "test", "handoffs": [], "tools": [], "output_type": "str"}, + "children": [ + { + "type": "guardrail", + "data": {"name": "output_guardrail_function", "triggered": True}, + } + ], + } + ], + } + ] + ) + spans = fetch_ordered_spans() assert len(spans) == 2, ( f"should have 1 agent, 1 guardrail, got {len(spans)} with data: " diff --git a/tests/testing_processor.py b/tests/testing_processor.py index 258a08dc..e5cb6f50 100644 --- a/tests/testing_processor.py +++ b/tests/testing_processor.py @@ -1,6 +1,7 @@ from __future__ import annotations import threading +from datetime import datetime from typing import Any, Literal from agents.tracing import Span, Trace, TracingProcessor @@ -77,3 +78,35 @@ def fetch_traces() -> list[Trace]: def fetch_events() -> list[TestSpanProcessorEvent]: return SPAN_PROCESSOR_TESTING._events + + +def fetch_normalized_spans(): + nodes: dict[tuple[str, str | None], dict[str, Any]] = {} + traces = [] + for trace_obj in fetch_traces(): + trace = trace_obj.export() + assert trace.pop("object") == "trace" + assert trace.pop("id").startswith("trace_") + trace = {k: v for k, v in trace.items() if v is not None} + nodes[(trace_obj.trace_id, None)] = trace + traces.append(trace) + + if not traces: + assert not fetch_ordered_spans() + + for span_obj in fetch_ordered_spans(): + span = span_obj.export() + assert span.pop("object") == "trace.span" + assert span.pop("id").startswith("span_") + assert datetime.fromisoformat(span.pop("started_at")) + assert datetime.fromisoformat(span.pop("ended_at")) + parent_id = span.pop("parent_id") + assert "type" not in span + span_data = span.pop("span_data") + span = {"type": span_data.pop("type")} | {k: v for k, v in span.items() if v is not None} + span_data = {k: v for k, v in span_data.items() if v is not None} + if span_data: + span["data"] = span_data + nodes[(span_obj.trace_id, span_obj.span_id)] = span + nodes[(span.pop("trace_id"), parent_id)].setdefault("children", []).append(span) + return traces diff --git a/uv.lock b/uv.lock index 2bceea75..fd28b2b6 100644 --- a/uv.lock +++ b/uv.lock @@ -26,6 +26,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 }, ] +[[package]] +name = "asttokens" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 }, +] + [[package]] name = "babel" version = "2.17.0" @@ -240,6 +249,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, ] +[[package]] +name = "executing" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, +] + [[package]] name = "ghp-import" version = "2.1.0" @@ -392,6 +410,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, ] +[[package]] +name = "inline-snapshot" +version = "0.20.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "rich" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3b/95/9b85a63031c168dd1c479f8cfd5cae42d42d6ac41c18dd760a104bc87ddc/inline_snapshot-0.20.5.tar.gz", hash = "sha256:d8b67c6d533c0a3f566e72608144b54da65dc3da5d0dba4169b2c56b75530fb5", size = 92215 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/71/34e775bbf0bcf81d588d80a1df93437f937b0df9a841f246606a03fc5eff/inline_snapshot-0.20.5-py3-none-any.whl", hash = "sha256:3aa56acf5985d89f17ebd4df4aef00faacc49f10cdf4e6b42be701ffc9702b5a", size = 48071 }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -797,6 +830,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "coverage" }, + { name = "inline-snapshot" }, { name = "mkdocs" }, { name = "mkdocs-material" }, { name = "mkdocstrings", extra = ["python"] }, @@ -822,6 +856,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "coverage", specifier = ">=7.6.12" }, + { name = "inline-snapshot", specifier = ">=0.20.5" }, { name = "mkdocs", specifier = ">=1.6.0" }, { name = "mkdocs-material", specifier = ">=9.6.0" }, { name = "mkdocstrings", extras = ["python"], specifier = ">=0.28.0" }, From 7eb2bcee15b8077c4ce002df59af4a44de2b62d8 Mon Sep 17 00:00:00 2001 From: Alex Hall Date: Mon, 17 Mar 2025 23:56:42 +0200 Subject: [PATCH 3/3] mypy --- tests/testing_processor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/testing_processor.py b/tests/testing_processor.py index e5cb6f50..371ea865 100644 --- a/tests/testing_processor.py +++ b/tests/testing_processor.py @@ -85,6 +85,7 @@ def fetch_normalized_spans(): traces = [] for trace_obj in fetch_traces(): trace = trace_obj.export() + assert trace assert trace.pop("object") == "trace" assert trace.pop("id").startswith("trace_") trace = {k: v for k, v in trace.items() if v is not None} @@ -96,6 +97,7 @@ def fetch_normalized_spans(): for span_obj in fetch_ordered_spans(): span = span_obj.export() + assert span assert span.pop("object") == "trace.span" assert span.pop("id").startswith("span_") assert datetime.fromisoformat(span.pop("started_at"))