diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index 69c500ab..73eba8f1 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -98,7 +98,11 @@ def set_default_openai_key(key: str) -> None:
     If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
     """
-    _config.set_default_openai_key(key)
+    try:
+        _config.set_default_openai_key(key)
+    except Exception as e:
+        logging.error(f"Error setting default OpenAI key: {e}")
+        raise
 
 
 def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:
@@ -111,22 +115,34 @@ def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True)
     you'll either need to set the OPENAI_API_KEY environment variable or call
     set_tracing_export_api_key() with the API key you want to use for tracing.
     """
-    _config.set_default_openai_client(client, use_for_tracing)
+    try:
+        _config.set_default_openai_client(client, use_for_tracing)
+    except Exception as e:
+        logging.error(f"Error setting default OpenAI client: {e}")
+        raise
 
 
 def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
     """Set the default API to use for OpenAI LLM requests. By default, we will use the responses
     API but you can set this to use the chat completions API instead.
     """
-    _config.set_default_openai_api(api)
+    try:
+        _config.set_default_openai_api(api)
+    except Exception as e:
+        logging.error(f"Error setting default OpenAI API: {e}")
+        raise
 
 
-def enable_verbose_stdout_logging():
+def enable_verbose_stdout_logging() -> None:
     """Enables verbose logging to stdout. This is useful for debugging."""
-    for name in ["openai.agents", "openai.agents.tracing"]:
-        logger = logging.getLogger(name)
-        logger.setLevel(logging.DEBUG)
-        logger.addHandler(logging.StreamHandler(sys.stdout))
+    try:
+        for name in ["openai.agents", "openai.agents.tracing"]:
+            logger = logging.getLogger(name)
+            logger.setLevel(logging.DEBUG)
+            logger.addHandler(logging.StreamHandler(sys.stdout))
+    except Exception as e:
+        logging.error(f"Error enabling verbose stdout logging: {e}")
+        raise
 
 
 __all__ = [
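
Example (sketch, not part of the patch): the wrappers above all follow the same
log-and-re-raise contract. The `configure` function here is a hypothetical
stand-in for the `_config` setters, not an SDK API.

    import logging

    logger = logging.getLogger("openai.agents")

    def configure(value: str) -> None:
        try:
            if not value:
                raise ValueError("value must be non-empty")
            # the real setter would apply the configuration here
        except Exception as e:
            logger.error(f"Error configuring value: {e}")
            raise  # bare raise keeps the original type and traceback
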
diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py
index 2c849506..8bd3d01f 100644
--- a/src/agents/_run_impl.py
+++ b/src/agents/_run_impl.py
@@ -65,6 +65,7 @@
 class QueueCompleteSentinel:
+    """A sentinel value to indicate that the queue is complete."""
     pass
 
 
@@ -97,8 +98,7 @@ class ProcessedResponse:
     computer_actions: list[ToolRunComputerAction]
 
     def has_tools_to_run(self) -> bool:
-        # Handoffs, functions and computer actions need local processing
-        # Hosted tools have already run, so there's nothing to do.
+        """Whether handoffs, function calls, or computer actions still need local processing."""
         return any(
             [
                 self.handoffs,
@@ -151,6 +151,7 @@ def generated_items(self) -> list[RunItem]:
 def get_model_tracing_impl(
     tracing_disabled: bool, trace_include_sensitive_data: bool
 ) -> ModelTracing:
+    """Get the model tracing implementation based on the tracing configuration."""
     if tracing_disabled:
         return ModelTracing.DISABLED
     elif trace_include_sensitive_data:
@@ -176,6 +177,7 @@ async def execute_tools_and_side_effects(
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
     ) -> SingleStepResult:
+        """Execute tools and side effects for the current step."""
         # Make a copy of the generated items
         pre_step_items = list(pre_step_items)
 
@@ -271,6 +273,7 @@ def process_model_response(
         output_schema: AgentOutputSchema | None,
         handoffs: list[Handoff],
     ) -> ProcessedResponse:
+        """Process the model response and extract relevant information."""
         items: list[RunItem] = []
 
         run_handoffs = []
@@ -356,6 +359,7 @@ async def execute_function_tool_calls(
         context_wrapper: RunContextWrapper[TContext],
         config: RunConfig,
     ) -> list[RunItem]:
+        """Execute function tool calls."""
         async def run_single_tool(
             func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
         ) -> str:
@@ -422,6 +426,7 @@ async def execute_computer_actions(
         context_wrapper: RunContextWrapper[TContext],
         config: RunConfig,
     ) -> list[RunItem]:
+        """Execute computer actions."""
         results: list[RunItem] = []
         # Need to run these serially, because each action can affect the computer state
         for action in actions:
@@ -451,6 +456,7 @@ async def execute_handoffs(
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
     ) -> SingleStepResult:
+        """Execute handoffs."""
         # If there is more than one handoff, add tool responses that reject those handoffs
         if len(run_handoffs) > 1:
             output_message = "Multiple handoffs detected, ignoring this one."
@@ -470,9 +476,20 @@ async def execute_handoffs(
         actual_handoff = run_handoffs[0]
         with handoff_span(from_agent=agent.name) as span_handoff:
             handoff = actual_handoff.handoff
-            new_agent: Agent[Any] = await handoff.on_invoke_handoff(
-                context_wrapper, actual_handoff.tool_call.arguments
-            )
+            try:
+                new_agent: Agent[Any] = await handoff.on_invoke_handoff(
+                    context_wrapper, actual_handoff.tool_call.arguments
+                )
+            except Exception as e:
+                _utils.attach_error_to_span(
+                    span_handoff,
+                    SpanError(
+                        message="Error invoking handoff",
+                        data={"error": str(e)},
+                    )
+                )
+                raise
+
             span_handoff.span_data.to_agent = new_agent.name
 
             # Append a tool output item for the handoff
@@ -568,6 +585,7 @@ async def execute_final_output(
         hooks: RunHooks[TContext],
         context_wrapper: RunContextWrapper[TContext],
     ) -> SingleStepResult:
+        """Execute final output."""
         # Run the on_end hooks
         await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)
 
@@ -587,6 +605,7 @@ async def run_final_output_hooks(
         context_wrapper: RunContextWrapper[TContext],
         final_output: Any,
     ):
+        """Run the final output hooks."""
         await asyncio.gather(
             hooks.on_agent_end(context_wrapper, agent, final_output),
             agent.hooks.on_end(context_wrapper, agent, final_output)
@@ -602,8 +621,19 @@ async def run_single_input_guardrail(
         input: str | list[TResponseInputItem],
         context: RunContextWrapper[TContext],
     ) -> InputGuardrailResult:
+        """Run a single input guardrail."""
         with guardrail_span(guardrail.get_name()) as span_guardrail:
-            result = await guardrail.run(agent, input, context)
+            try:
+                result = await guardrail.run(agent, input, context)
+            except Exception as e:
+                _utils.attach_error_to_span(
+                    span_guardrail,
+                    SpanError(
+                        message="Error running input guardrail",
+                        data={"error": str(e)},
+                    )
+                )
+                raise
             span_guardrail.span_data.triggered = result.output.tripwire_triggered
             return result
 
@@ -615,8 +645,19 @@ async def run_single_output_guardrail(
         agent_output: Any,
         context: RunContextWrapper[TContext],
     ) -> OutputGuardrailResult:
+        """Run a single output guardrail."""
         with guardrail_span(guardrail.get_name()) as span_guardrail:
-            result = await guardrail.run(agent=agent, agent_output=agent_output, context=context)
+            try:
+                result = await guardrail.run(agent=agent, agent_output=agent_output, context=context)
+            except Exception as e:
+                _utils.attach_error_to_span(
+                    span_guardrail,
+                    SpanError(
+                        message="Error running output guardrail",
+                        data={"error": str(e)},
+                    )
+                )
+                raise
             span_guardrail.span_data.triggered = result.output.tripwire_triggered
             return result
 
@@ -626,6 +667,7 @@ def stream_step_result_to_queue(
         step_result: SingleStepResult,
         queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel],
     ):
+        """Stream the step result to the queue."""
         for item in step_result.new_step_items:
             if isinstance(item, MessageOutputItem):
                 event = RunItemStreamEvent(item=item, name="message_output_created")
@@ -695,6 +737,7 @@ async def execute(
         context_wrapper: RunContextWrapper[TContext],
         config: RunConfig,
     ) -> RunItem:
+        """Execute a computer action."""
         output_func = (
             cls._get_screenshot_async(action.computer_tool.computer, action.tool_call)
             if isinstance(action.computer_tool.computer, AsyncComputer)
@@ -741,25 +784,29 @@ async def _get_screenshot_sync(
         computer: Computer,
         tool_call: ResponseComputerToolCall,
     ) -> str:
+        """Get a screenshot synchronously."""
         action = tool_call.action
-        if isinstance(action, ActionClick):
-            computer.click(action.x, action.y, action.button)
-        elif isinstance(action, ActionDoubleClick):
-            computer.double_click(action.x, action.y)
-        elif isinstance(action, ActionDrag):
-            computer.drag([(p.x, p.y) for p in action.path])
-        elif isinstance(action, ActionKeypress):
-            computer.keypress(action.keys)
-        elif isinstance(action, ActionMove):
-            computer.move(action.x, action.y)
-        elif isinstance(action, ActionScreenshot):
-            computer.screenshot()
-        elif isinstance(action, ActionScroll):
-            computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
-        elif isinstance(action, ActionType):
-            computer.type(action.text)
-        elif isinstance(action, ActionWait):
-            computer.wait()
+        try:
+            if isinstance(action, ActionClick):
+                computer.click(action.x, action.y, action.button)
+            elif isinstance(action, ActionDoubleClick):
+                computer.double_click(action.x, action.y)
+            elif isinstance(action, ActionDrag):
+                computer.drag([(p.x, p.y) for p in action.path])
+            elif isinstance(action, ActionKeypress):
+                computer.keypress(action.keys)
+            elif isinstance(action, ActionMove):
+                computer.move(action.x, action.y)
+            elif isinstance(action, ActionScreenshot):
+                computer.screenshot()
+            elif isinstance(action, ActionScroll):
+                computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
+            elif isinstance(action, ActionType):
+                computer.type(action.text)
+            elif isinstance(action, ActionWait):
+                computer.wait()
+        except Exception as e:
+            raise ModelBehaviorError(f"Error executing computer action: {e}") from e
 
         return computer.screenshot()
 
@@ -769,24 +816,28 @@ async def _get_screenshot_async(
         computer: AsyncComputer,
         tool_call: ResponseComputerToolCall,
     ) -> str:
+        """Get a screenshot asynchronously."""
         action = tool_call.action
-        if isinstance(action, ActionClick):
-            await computer.click(action.x, action.y, action.button)
-        elif isinstance(action, ActionDoubleClick):
-            await computer.double_click(action.x, action.y)
-        elif isinstance(action, ActionDrag):
-            await computer.drag([(p.x, p.y) for p in action.path])
-        elif isinstance(action, ActionKeypress):
-            await computer.keypress(action.keys)
-        elif isinstance(action, ActionMove):
-            await computer.move(action.x, action.y)
-        elif isinstance(action, ActionScreenshot):
-            await computer.screenshot()
-        elif isinstance(action, ActionScroll):
-            await computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
-        elif isinstance(action, ActionType):
-            await computer.type(action.text)
-        elif isinstance(action, ActionWait):
-            await computer.wait()
+        try:
+            if isinstance(action, ActionClick):
+                await computer.click(action.x, action.y, action.button)
+            elif isinstance(action, ActionDoubleClick):
+                await computer.double_click(action.x, action.y)
+            elif isinstance(action, ActionDrag):
+                await computer.drag([(p.x, p.y) for p in action.path])
+            elif isinstance(action, ActionKeypress):
+                await computer.keypress(action.keys)
+            elif isinstance(action, ActionMove):
+                await computer.move(action.x, action.y)
+            elif isinstance(action, ActionScreenshot):
+                await computer.screenshot()
+            elif isinstance(action, ActionScroll):
+                await computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y)
+            elif isinstance(action, ActionType):
+                await computer.type(action.text)
+            elif isinstance(action, ActionWait):
+                await computer.wait()
+        except Exception as e:
+            raise ModelBehaviorError(f"Error executing computer action: {e}") from e
 
         return await computer.screenshot()
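
Example (sketch, not part of the patch): chaining with `from e`, as applied in
the two screenshot helpers above, keeps the original failure attached as
`__cause__`. The `ModelBehaviorError` here is a local stand-in class.

    class ModelBehaviorError(Exception):
        pass

    def act() -> None:
        raise OSError("display not available")

    try:
        try:
            act()
        except Exception as e:
            raise ModelBehaviorError(f"Error executing computer action: {e}") from e
    except ModelBehaviorError as err:
        assert isinstance(err.__cause__, OSError)  # original error preserved
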
diff --git a/src/agents/agent.py b/src/agents/agent.py
index 61c0a896..33fe2839 100644
--- a/src/agents/agent.py
+++ b/src/agents/agent.py
@@ -132,28 +132,36 @@ def as_tool(
         async def run_agent(context: RunContextWrapper, input: str) -> str:
             from .run import Runner
 
-            output = await Runner.run(
-                starting_agent=self,
-                input=input,
-                context=context.context,
-            )
-            if custom_output_extractor:
-                return await custom_output_extractor(output)
-
-            return ItemHelpers.text_message_outputs(output.new_items)
+            try:
+                output = await Runner.run(
+                    starting_agent=self,
+                    input=input,
+                    context=context.context,
+                )
+                if custom_output_extractor:
+                    return await custom_output_extractor(output)
+
+                return ItemHelpers.text_message_outputs(output.new_items)
+            except Exception as e:
+                logger.error(f"Error running agent as tool: {e}")
+                raise
 
         return run_agent
 
     async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
         """Get the system prompt for the agent."""
-        if isinstance(self.instructions, str):
-            return self.instructions
-        elif callable(self.instructions):
-            if inspect.iscoroutinefunction(self.instructions):
-                return await cast(Awaitable[str], self.instructions(run_context, self))
-            else:
-                return cast(str, self.instructions(run_context, self))
-        elif self.instructions is not None:
-            logger.error(f"Instructions must be a string or a function, got {self.instructions}")
-
-        return None
+        try:
+            if isinstance(self.instructions, str):
+                return self.instructions
+            elif callable(self.instructions):
+                if inspect.iscoroutinefunction(self.instructions):
+                    return await cast(Awaitable[str], self.instructions(run_context, self))
+                else:
+                    return cast(str, self.instructions(run_context, self))
+            elif self.instructions is not None:
+                logger.error(f"Instructions must be a string or a function, got {self.instructions}")
+
+            return None
+        except Exception as e:
+            logger.error(f"Error getting system prompt: {e}")
+            raise
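
Example (sketch, not part of the patch): `get_system_prompt` above accepts a
plain string or a (possibly async) callable. The callable form, assuming the
public `Agent` and `RunContextWrapper` exports:

    from agents import Agent, RunContextWrapper

    def dynamic_instructions(context: RunContextWrapper, agent: Agent) -> str:
        # Evaluated on each run; an async def works here as well.
        return f"You are {agent.name}. Answer briefly."

    agent = Agent(name="Helper", instructions=dynamic_instructions)
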
""" - validated = _utils.validate_json(json_str, self._type_adapter, partial) - if self._is_wrapped: - if not isinstance(validated, dict): - _utils.attach_error_to_current_span( - SpanError( - message="Invalid JSON", - data={"details": f"Expected a dict, got {type(validated)}"}, + try: + validated = _utils.validate_json(json_str, self._type_adapter, partial) + if self._is_wrapped: + if not isinstance(validated, dict): + _utils.attach_error_to_current_span( + SpanError( + message="Invalid JSON", + data={"details": f"Expected a dict, got {type(validated)}"}, + ) + ) + raise ModelBehaviorError( + f"Expected a dict, got {type(validated)} for JSON: {json_str}" ) - ) - raise ModelBehaviorError( - f"Expected a dict, got {type(validated)} for JSON: {json_str}" - ) - if _WRAPPER_DICT_KEY not in validated: - _utils.attach_error_to_current_span( - SpanError( - message="Invalid JSON", - data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"}, + if _WRAPPER_DICT_KEY not in validated: + _utils.attach_error_to_current_span( + SpanError( + message="Invalid JSON", + data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"}, + ) ) + raise ModelBehaviorError( + f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}" + ) + return validated[_WRAPPER_DICT_KEY] + return validated + except Exception as e: + _utils.attach_error_to_current_span( + SpanError( + message="Error validating JSON", + data={"error": str(e)}, ) - raise ModelBehaviorError( - f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}" - ) - return validated[_WRAPPER_DICT_KEY] - return validated + ) + raise def output_type_name(self) -> str: """The name of the output type.""" @@ -119,6 +128,7 @@ def output_type_name(self) -> str: def _is_subclass_of_base_model_or_dict(t: Any) -> bool: + """Check if a type is a subclass of BaseModel or dict.""" if not isinstance(t, type): return False @@ -131,6 +141,7 @@ def _is_subclass_of_base_model_or_dict(t: Any) -> bool: def _type_to_str(t: type[Any]) -> str: + """Convert a type to its string representation.""" origin = get_origin(t) args = get_args(t) diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 78898f01..93f56d93 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -6,6 +6,7 @@ class AgentsException(Exception): """Base class for all exceptions in the Agents SDK.""" + pass class MaxTurnsExceeded(AgentsException): @@ -13,7 +14,7 @@ class MaxTurnsExceeded(AgentsException): message: str - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message @@ -24,7 +25,7 @@ class ModelBehaviorError(AgentsException): message: str - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message @@ -33,7 +34,7 @@ class UserError(AgentsException): message: str - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message @@ -43,7 +44,7 @@ class InputGuardrailTripwireTriggered(AgentsException): guardrail_result: "InputGuardrailResult" """The result data of the guardrail that was triggered.""" - def __init__(self, guardrail_result: "InputGuardrailResult"): + def __init__(self, guardrail_result: "InputGuardrailResult") -> None: self.guardrail_result = guardrail_result super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" @@ -56,7 +57,7 @@ class OutputGuardrailTripwireTriggered(AgentsException): guardrail_result: "OutputGuardrailResult" """The result data of the guardrail that was 
triggered.""" - def __init__(self, guardrail_result: "OutputGuardrailResult"): + def __init__(self, guardrail_result: "OutputGuardrailResult") -> None: self.guardrail_result = guardrail_result super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index a4b57672..1c1ebe27 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -87,6 +87,15 @@ class FuncDocumentation: # As of Feb 2025, the automatic style detection in griffe is an Insiders feature. This # code approximates it. def _detect_docstring_style(doc: str) -> DocstringStyle: + """ + Detects the style of a docstring. + + Args: + doc: The docstring to detect the style of. + + Returns: + The detected docstring style. + """ scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0} # Sphinx style detection: look for :param, :type, :return:, and :rtype: @@ -128,6 +137,9 @@ def _detect_docstring_style(doc: str) -> DocstringStyle: @contextlib.contextmanager def _suppress_griffe_logging(): + """ + Context manager to suppress griffe logging. + """ # Supresses warnings about missing annotations for params logger = logging.getLogger("griffe") previous_level = logger.getEffectiveLevel() @@ -210,131 +222,133 @@ def function_schema( A `FuncSchema` object containing the function's name, description, parameter descriptions, and other metadata. """ - - # 1. Grab docstring info - if use_docstring_info: - doc_info = generate_func_documentation(func, docstring_style) - param_descs = doc_info.param_descriptions or {} - else: - doc_info = None - param_descs = {} - - func_name = name_override or doc_info.name if doc_info else func.__name__ - - # 2. Inspect function signature and get type hints - sig = inspect.signature(func) - type_hints = get_type_hints(func) - params = list(sig.parameters.items()) - takes_context = False - filtered_params = [] - - if params: - first_name, first_param = params[0] - # Prefer the evaluated type hint if available - ann = type_hints.get(first_name, first_param.annotation) - if ann != inspect._empty: - origin = get_origin(ann) or ann - if origin is RunContextWrapper: - takes_context = True # Mark that the function takes context - else: - filtered_params.append((first_name, first_param)) + try: + # 1. Grab docstring info + if use_docstring_info: + doc_info = generate_func_documentation(func, docstring_style) + param_descs = doc_info.param_descriptions or {} else: - filtered_params.append((first_name, first_param)) - - # For parameters other than the first, raise error if any use RunContextWrapper. - for name, param in params[1:]: - ann = type_hints.get(name, param.annotation) - if ann != inspect._empty: - origin = get_origin(ann) or ann - if origin is RunContextWrapper: - raise UserError( - f"RunContextWrapper param found at non-first position in function" - f" {func.__name__}" - ) - filtered_params.append((name, param)) - - # We will collect field definitions for create_model as a dict: - # field_name -> (type_annotation, default_value_or_Field(...)) - fields: dict[str, Any] = {} - - for name, param in filtered_params: - ann = type_hints.get(name, param.annotation) - default = param.default - - # If there's no type hint, assume `Any` - if ann == inspect._empty: - ann = Any - - # If a docstring param description exists, use it - field_description = param_descs.get(name, None) - - # Handle different parameter kinds - if param.kind == param.VAR_POSITIONAL: - # e.g. 
diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py
index a4b57672..1c1ebe27 100644
--- a/src/agents/function_schema.py
+++ b/src/agents/function_schema.py
@@ -87,6 +87,15 @@ class FuncDocumentation:
 # As of Feb 2025, the automatic style detection in griffe is an Insiders feature. This
 # code approximates it.
 def _detect_docstring_style(doc: str) -> DocstringStyle:
+    """
+    Detects the style of a docstring.
+
+    Args:
+        doc: The docstring to detect the style of.
+
+    Returns:
+        The detected docstring style.
+    """
     scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0}
 
     # Sphinx style detection: look for :param, :type, :return:, and :rtype:
@@ -128,6 +137,9 @@ def _detect_docstring_style(doc: str) -> DocstringStyle:
 
 @contextlib.contextmanager
 def _suppress_griffe_logging():
+    """
+    Context manager to suppress griffe logging.
+    """
     # Supresses warnings about missing annotations for params
     logger = logging.getLogger("griffe")
     previous_level = logger.getEffectiveLevel()
@@ -210,131 +222,135 @@ def function_schema(
     A `FuncSchema` object containing the function's name, description, parameter descriptions,
     and other metadata.
     """
-
-    # 1. Grab docstring info
-    if use_docstring_info:
-        doc_info = generate_func_documentation(func, docstring_style)
-        param_descs = doc_info.param_descriptions or {}
-    else:
-        doc_info = None
-        param_descs = {}
-
-    func_name = name_override or doc_info.name if doc_info else func.__name__
-
-    # 2. Inspect function signature and get type hints
-    sig = inspect.signature(func)
-    type_hints = get_type_hints(func)
-    params = list(sig.parameters.items())
-    takes_context = False
-    filtered_params = []
-
-    if params:
-        first_name, first_param = params[0]
-        # Prefer the evaluated type hint if available
-        ann = type_hints.get(first_name, first_param.annotation)
-        if ann != inspect._empty:
-            origin = get_origin(ann) or ann
-            if origin is RunContextWrapper:
-                takes_context = True  # Mark that the function takes context
-            else:
-                filtered_params.append((first_name, first_param))
-        else:
-            filtered_params.append((first_name, first_param))
-
-    # For parameters other than the first, raise error if any use RunContextWrapper.
-    for name, param in params[1:]:
-        ann = type_hints.get(name, param.annotation)
-        if ann != inspect._empty:
-            origin = get_origin(ann) or ann
-            if origin is RunContextWrapper:
-                raise UserError(
-                    f"RunContextWrapper param found at non-first position in function"
-                    f" {func.__name__}"
-                )
-        filtered_params.append((name, param))
-
-    # We will collect field definitions for create_model as a dict:
-    #   field_name -> (type_annotation, default_value_or_Field(...))
-    fields: dict[str, Any] = {}
-
-    for name, param in filtered_params:
-        ann = type_hints.get(name, param.annotation)
-        default = param.default
-
-        # If there's no type hint, assume `Any`
-        if ann == inspect._empty:
-            ann = Any
-
-        # If a docstring param description exists, use it
-        field_description = param_descs.get(name, None)
-
-        # Handle different parameter kinds
-        if param.kind == param.VAR_POSITIONAL:
-            # e.g. *args: extend positional args
-            if get_origin(ann) is tuple:
-                # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int]
-                args_of_tuple = get_args(ann)
-                if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis:
-                    ann = list[args_of_tuple[0]]  # type: ignore
-                else:
-                    ann = list[Any]
-            else:
-                # If user wrote *args: int, treat as List[int]
-                ann = list[ann]  # type: ignore
-
-            # Default factory to empty list
-            fields[name] = (
-                ann,
-                Field(default_factory=list, description=field_description),  # type: ignore
-            )
-
-        elif param.kind == param.VAR_KEYWORD:
-            # **kwargs handling
-            if get_origin(ann) is dict:
-                # e.g. def foo(**kwargs: dict[str, int])
-                dict_args = get_args(ann)
-                if len(dict_args) == 2:
-                    ann = dict[dict_args[0], dict_args[1]]  # type: ignore
-                else:
-                    ann = dict[str, Any]
-            else:
-                # e.g. def foo(**kwargs: int) -> Dict[str, int]
-                ann = dict[str, ann]  # type: ignore
-
-            fields[name] = (
-                ann,
-                Field(default_factory=dict, description=field_description),  # type: ignore
-            )
-
-        else:
-            # Normal parameter
-            if default == inspect._empty:
-                # Required field
-                fields[name] = (
-                    ann,
-                    Field(..., description=field_description),
-                )
-            else:
-                # Parameter with a default value
-                fields[name] = (
-                    ann,
-                    Field(default=default, description=field_description),
-                )
-
-    # 3. Dynamically build a Pydantic model
-    dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)
-
-    # 4. Build JSON schema from that model
-    json_schema = dynamic_model.model_json_schema()
-    if strict_json_schema:
-        json_schema = ensure_strict_json_schema(json_schema)
-
-    # 5. Return as a FuncSchema dataclass
-    return FuncSchema(
-        name=func_name,
-        description=description_override or doc_info.description if doc_info else None,
-        params_pydantic_model=dynamic_model,
-        params_json_schema=json_schema,
-        signature=sig,
-        takes_context=takes_context,
-    )
+    try:
+        # 1. Grab docstring info
+        if use_docstring_info:
+            doc_info = generate_func_documentation(func, docstring_style)
+            param_descs = doc_info.param_descriptions or {}
+        else:
+            doc_info = None
+            param_descs = {}
+
+        func_name = name_override or doc_info.name if doc_info else func.__name__
+
+        # 2. Inspect function signature and get type hints
+        sig = inspect.signature(func)
+        type_hints = get_type_hints(func)
+        params = list(sig.parameters.items())
+        takes_context = False
+        filtered_params = []
+
+        if params:
+            first_name, first_param = params[0]
+            # Prefer the evaluated type hint if available
+            ann = type_hints.get(first_name, first_param.annotation)
+            if ann != inspect._empty:
+                origin = get_origin(ann) or ann
+                if origin is RunContextWrapper:
+                    takes_context = True  # Mark that the function takes context
+                else:
+                    filtered_params.append((first_name, first_param))
+            else:
+                filtered_params.append((first_name, first_param))
+
+        # For parameters other than the first, raise error if any use RunContextWrapper.
+        for name, param in params[1:]:
+            ann = type_hints.get(name, param.annotation)
+            if ann != inspect._empty:
+                origin = get_origin(ann) or ann
+                if origin is RunContextWrapper:
+                    raise UserError(
+                        f"RunContextWrapper param found at non-first position in function"
+                        f" {func.__name__}"
+                    )
+            filtered_params.append((name, param))
+
+        # We will collect field definitions for create_model as a dict:
+        #   field_name -> (type_annotation, default_value_or_Field(...))
+        fields: dict[str, Any] = {}
+
+        for name, param in filtered_params:
+            ann = type_hints.get(name, param.annotation)
+            default = param.default
+
+            # If there's no type hint, assume `Any`
+            if ann == inspect._empty:
+                ann = Any
+
+            # If a docstring param description exists, use it
+            field_description = param_descs.get(name, None)
+
+            # Handle different parameter kinds
+            if param.kind == param.VAR_POSITIONAL:
+                # e.g. *args: extend positional args
+                if get_origin(ann) is tuple:
+                    # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int]
+                    args_of_tuple = get_args(ann)
+                    if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis:
+                        ann = list[args_of_tuple[0]]  # type: ignore
+                    else:
+                        ann = list[Any]
+                else:
+                    # If user wrote *args: int, treat as List[int]
+                    ann = list[ann]  # type: ignore
+
+                # Default factory to empty list
+                fields[name] = (
+                    ann,
+                    Field(default_factory=list, description=field_description),  # type: ignore
+                )
+
+            elif param.kind == param.VAR_KEYWORD:
+                # **kwargs handling
+                if get_origin(ann) is dict:
+                    # e.g. def foo(**kwargs: dict[str, int])
+                    dict_args = get_args(ann)
+                    if len(dict_args) == 2:
+                        ann = dict[dict_args[0], dict_args[1]]  # type: ignore
+                    else:
+                        ann = dict[str, Any]
+                else:
+                    # e.g. def foo(**kwargs: int) -> Dict[str, int]
+                    ann = dict[str, ann]  # type: ignore
+
+                fields[name] = (
+                    ann,
+                    Field(default_factory=dict, description=field_description),  # type: ignore
+                )
+
+            else:
+                # Normal parameter
+                if default == inspect._empty:
+                    # Required field
+                    fields[name] = (
+                        ann,
+                        Field(..., description=field_description),
+                    )
+                else:
+                    # Parameter with a default value
+                    fields[name] = (
+                        ann,
+                        Field(default=default, description=field_description),
+                    )
+
+        # 3. Dynamically build a Pydantic model
+        dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields)
+
+        # 4. Build JSON schema from that model
+        json_schema = dynamic_model.model_json_schema()
+        if strict_json_schema:
+            json_schema = ensure_strict_json_schema(json_schema)
+
+        # 5. Return as a FuncSchema dataclass
+        return FuncSchema(
+            name=func_name,
+            description=description_override or doc_info.description if doc_info else None,
+            params_pydantic_model=dynamic_model,
+            params_json_schema=json_schema,
+            signature=sig,
+            takes_context=takes_context,
+        )
+    except UserError:
+        raise
+    except Exception as e:
+        raise UserError(f"Error generating function schema: {e}") from e
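
Example (sketch, not part of the patch): `function_schema` driven by the
docstring and signature handling above.

    from agents.function_schema import function_schema

    def get_weather(city: str, unit: str = "celsius") -> str:
        """Fetch the weather.

        Args:
            city: Name of the city.
            unit: Temperature unit to report.
        """
        return f"Sunny in {city} ({unit})"

    fs = function_schema(get_weather)
    print(fs.name)                # "get_weather"
    print(fs.takes_context)       # False
    print(fs.params_json_schema)  # strict JSON schema derived from the signature
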
UserError(f"Guardrail function must be callable, got {self.guardrail_function}") - output = self.guardrail_function(context, agent, agent_output) - if inspect.isawaitable(output): + try: + output = self.guardrail_function(context, agent, agent_output) + if inspect.isawaitable(output): + return OutputGuardrailResult( + guardrail=self, + agent=agent, + agent_output=agent_output, + output=await output, + ) + return OutputGuardrailResult( guardrail=self, agent=agent, agent_output=agent_output, - output=await output, + output=output, ) - - return OutputGuardrailResult( - guardrail=self, - agent=agent, - agent_output=agent_output, - output=output, - ) + except Exception as e: + raise UserError(f"Error running output guardrail: {e}") from e TContext_co = TypeVar("TContext_co", bound=Any, covariant=True) diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index ac157401..fd25e912 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -28,6 +28,10 @@ @dataclass(frozen=True) class HandoffInputData: + """ + The input data passed to the next agent during a handoff. + """ + input_history: str | tuple[TResponseInputItem, ...] """ The input history before `Runner.run()` was called. @@ -51,7 +55,8 @@ class HandoffInputData: @dataclass class Handoff(Generic[TContext]): - """A handoff is when an agent delegates a task to another agent. + """ + A handoff is when an agent delegates a task to another agent. For example, in a customer support scenario you might have a "triage agent" that determines which agent should handle the user's request, and sub-agents that specialize in different areas like billing, account management, etc. @@ -99,15 +104,42 @@ class Handoff(Generic[TContext]): """ def get_transfer_message(self, agent: Agent[Any]) -> str: + """ + Get the transfer message for the handoff. + + Args: + agent: The agent that is being handed off to. + + Returns: + The transfer message. + """ base = f"{{'assistant': '{agent.name}'}}" return base @classmethod def default_tool_name(cls, agent: Agent[Any]) -> str: + """ + Get the default tool name for the handoff. + + Args: + agent: The agent that is being handed off to. + + Returns: + The default tool name. + """ return _utils.transform_string_function_style(f"transfer_to_{agent.name}") @classmethod def default_tool_description(cls, agent: Agent[Any]) -> str: + """ + Get the default tool description for the handoff. + + Args: + agent: The agent that is being handed off to. + + Returns: + The default tool description. + """ return ( f"Handoff to the {agent.name} agent to handle the request. 
" f"{agent.handoff_description or ''}" @@ -190,34 +222,43 @@ def handoff( async def _invoke_handoff( ctx: RunContextWrapper[Any], input_json: str | None = None ) -> Agent[Any]: - if input_type is not None and type_adapter is not None: - if input_json is None: - _utils.attach_error_to_current_span( - SpanError( - message="Handoff function expected non-null input, but got None", - data={"details": "input_json is None"}, + try: + if input_type is not None and type_adapter is not None: + if input_json is None: + _utils.attach_error_to_current_span( + SpanError( + message="Handoff function expected non-null input, but got None", + data={"details": "input_json is None"}, + ) ) - ) - raise ModelBehaviorError("Handoff function expected non-null input, but got None") + raise ModelBehaviorError("Handoff function expected non-null input, but got None") - validated_input = _utils.validate_json( - json_str=input_json, - type_adapter=type_adapter, - partial=False, + validated_input = _utils.validate_json( + json_str=input_json, + type_adapter=type_adapter, + partial=False, + ) + input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff) + if inspect.iscoroutinefunction(input_func): + await input_func(ctx, validated_input) + else: + input_func(ctx, validated_input) + elif on_handoff is not None: + no_input_func = cast(OnHandoffWithoutInput, on_handoff) + if inspect.iscoroutinefunction(no_input_func): + await no_input_func(ctx) + else: + no_input_func(ctx) + + return agent + except Exception as e: + _utils.attach_error_to_current_span( + SpanError( + message="Error invoking handoff", + data={"error": str(e)}, + ) ) - input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff) - if inspect.iscoroutinefunction(input_func): - await input_func(ctx, validated_input) - else: - input_func(ctx, validated_input) - elif on_handoff is not None: - no_input_func = cast(OnHandoffWithoutInput, on_handoff) - if inspect.iscoroutinefunction(no_input_func): - await no_input_func(ctx) - else: - no_input_func(ctx) - - return agent + raise tool_name = tool_name_override or Handoff.default_tool_name(agent) tool_description = tool_description_override or Handoff.default_tool_description(agent) diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 3543225d..778ea21d 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -115,17 +115,21 @@ async def get_response( | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: - response = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - span_generation, - tracing, - stream=False, - ) + try: + response = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + span_generation, + tracing, + stream=False, + ) + except Exception as e: + logger.error(f"Error fetching response: {e}") + raise if _debug.DONT_LOG_MODEL_DATA: logger.debug("Received model response") @@ -178,17 +182,21 @@ async def stream_response( | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: - response, stream = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - span_generation, - tracing, - stream=True, - ) + try: + response, stream = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + 
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 3543225d..778ea21d 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -115,17 +115,21 @@ async def get_response(
             | {"base_url": str(self._client.base_url)},
             disabled=tracing.is_disabled(),
         ) as span_generation:
-            response = await self._fetch_response(
-                system_instructions,
-                input,
-                model_settings,
-                tools,
-                output_schema,
-                handoffs,
-                span_generation,
-                tracing,
-                stream=False,
-            )
+            try:
+                response = await self._fetch_response(
+                    system_instructions,
+                    input,
+                    model_settings,
+                    tools,
+                    output_schema,
+                    handoffs,
+                    span_generation,
+                    tracing,
+                    stream=False,
+                )
+            except Exception as e:
+                logger.error(f"Error fetching response: {e}")
+                raise
 
             if _debug.DONT_LOG_MODEL_DATA:
                 logger.debug("Received model response")
@@ -178,17 +182,21 @@ async def stream_response(
             | {"base_url": str(self._client.base_url)},
             disabled=tracing.is_disabled(),
         ) as span_generation:
-            response, stream = await self._fetch_response(
-                system_instructions,
-                input,
-                model_settings,
-                tools,
-                output_schema,
-                handoffs,
-                span_generation,
-                tracing,
-                stream=True,
-            )
+            try:
+                response, stream = await self._fetch_response(
+                    system_instructions,
+                    input,
+                    model_settings,
+                    tools,
+                    output_schema,
+                    handoffs,
+                    span_generation,
+                    tracing,
+                    stream=True,
+                )
+            except Exception as e:
+                logger.error(f"Error fetching response: {e}")
+                raise
 
         usage: CompletionUsage | None = None
         state = _StreamingState()
@@ -513,22 +521,26 @@ async def _fetch_response(
                 f"Response format: {response_format}\n"
             )
 
-        ret = await self._get_client().chat.completions.create(
-            model=self.model,
-            messages=converted_messages,
-            tools=converted_tools or NOT_GIVEN,
-            temperature=self._non_null_or_not_given(model_settings.temperature),
-            top_p=self._non_null_or_not_given(model_settings.top_p),
-            frequency_penalty=self._non_null_or_not_given(model_settings.frequency_penalty),
-            presence_penalty=self._non_null_or_not_given(model_settings.presence_penalty),
-            max_tokens=self._non_null_or_not_given(model_settings.max_tokens),
-            tool_choice=tool_choice,
-            response_format=response_format,
-            parallel_tool_calls=parallel_tool_calls,
-            stream=stream,
-            stream_options={"include_usage": True} if stream else NOT_GIVEN,
-            extra_headers=_HEADERS,
-        )
+        try:
+            ret = await self._get_client().chat.completions.create(
+                model=self.model,
+                messages=converted_messages,
+                tools=converted_tools or NOT_GIVEN,
+                temperature=self._non_null_or_not_given(model_settings.temperature),
+                top_p=self._non_null_or_not_given(model_settings.top_p),
+                frequency_penalty=self._non_null_or_not_given(model_settings.frequency_penalty),
+                presence_penalty=self._non_null_or_not_given(model_settings.presence_penalty),
+                max_tokens=self._non_null_or_not_given(model_settings.max_tokens),
+                tool_choice=tool_choice,
+                response_format=response_format,
+                parallel_tool_calls=parallel_tool_calls,
+                stream=stream,
+                stream_options={"include_usage": True} if stream else NOT_GIVEN,
+                extra_headers=_HEADERS,
+            )
+        except Exception as e:
+            logger.error(f"Error creating chat completion: {e}")
+            raise
 
         if isinstance(ret, ChatCompletion):
             return ret
diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index 78765ecb..32783889 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -70,6 +70,21 @@ async def get_response(
         handoffs: list[Handoff],
         tracing: ModelTracing,
     ) -> ModelResponse:
+        """
+        Get a response from the model.
+
+        Args:
+            system_instructions: The system instructions to use.
+            input: The input items to the model, in OpenAI Responses format.
+            model_settings: The model settings to use.
+            tools: The tools available to the model.
+            output_schema: The output schema to use.
+            handoffs: The handoffs available to the model.
+            tracing: Tracing configuration.
+
+        Returns:
+            The full model response.
+        """
         with response_span(disabled=tracing.is_disabled()) as span_response:
             try:
                 response = await self._fetch_response(
@@ -205,6 +220,21 @@ async def _fetch_response(
         handoffs: list[Handoff],
         stream: Literal[True] | Literal[False] = False,
     ) -> Response | AsyncStream[ResponseStreamEvent]:
+        """
+        Fetch a response from the model.
+
+        Args:
+            system_instructions: The system instructions to use.
+            input: The input items to the model, in OpenAI Responses format.
+            model_settings: The model settings to use.
+            tools: The tools available to the model.
+            output_schema: The output schema to use.
+            handoffs: The handoffs available to the model.
+            stream: Whether to stream the response.
+
+        Returns:
+            The model response or an async stream of response events.
+        """
         list_input = ItemHelpers.input_to_new_input_list(input)
 
         parallel_tool_calls = (
@@ -227,24 +257,34 @@ async def _fetch_response(
                 f"Response format: {response_format}\n"
             )
 
-        return await self._client.responses.create(
-            instructions=self._non_null_or_not_given(system_instructions),
-            model=self.model,
-            input=list_input,
-            include=converted_tools.includes,
-            tools=converted_tools.tools,
-            temperature=self._non_null_or_not_given(model_settings.temperature),
-            top_p=self._non_null_or_not_given(model_settings.top_p),
-            truncation=self._non_null_or_not_given(model_settings.truncation),
-            max_output_tokens=self._non_null_or_not_given(model_settings.max_tokens),
-            tool_choice=tool_choice,
-            parallel_tool_calls=parallel_tool_calls,
-            stream=stream,
-            extra_headers=_HEADERS,
-            text=response_format,
-        )
+        try:
+            return await self._client.responses.create(
+                instructions=self._non_null_or_not_given(system_instructions),
+                model=self.model,
+                input=list_input,
+                include=converted_tools.includes,
+                tools=converted_tools.tools,
+                temperature=self._non_null_or_not_given(model_settings.temperature),
+                top_p=self._non_null_or_not_given(model_settings.top_p),
+                truncation=self._non_null_or_not_given(model_settings.truncation),
+                max_output_tokens=self._non_null_or_not_given(model_settings.max_tokens),
+                tool_choice=tool_choice,
+                parallel_tool_calls=parallel_tool_calls,
+                stream=stream,
+                extra_headers=_HEADERS,
+                text=response_format,
+            )
+        except Exception as e:
+            logger.error(f"Error fetching response: {e}")
+            raise
 
     def _get_client(self) -> AsyncOpenAI:
+        """
+        Get the OpenAI client.
+
+        Returns:
+            The OpenAI client.
+        """
         if self._client is None:
             self._client = AsyncOpenAI()
         return self._client
@@ -261,6 +301,15 @@ class Converter:
     def convert_tool_choice(
         cls, tool_choice: Literal["auto", "required", "none"] | str | None
     ) -> response_create_params.ToolChoice | NotGiven:
+        """
+        Convert the tool choice to the appropriate format.
+
+        Args:
+            tool_choice: The tool choice.
+
+        Returns:
+            The converted tool choice.
+        """
         if tool_choice is None:
             return NOT_GIVEN
         elif tool_choice == "required":
@@ -291,6 +340,15 @@ def convert_tool_choice(
     def get_response_format(
         cls, output_schema: AgentOutputSchema | None
     ) -> ResponseTextConfigParam | NotGiven:
+        """
+        Get the response format based on the output schema.
+
+        Args:
+            output_schema: The output schema.
+
+        Returns:
+            The response format.
+        """
         if output_schema is None or output_schema.is_plain_text():
             return NOT_GIVEN
         else:
@@ -309,6 +367,16 @@ def convert_tools(
         tools: list[Tool],
         handoffs: list[Handoff[Any]],
     ) -> ConvertedTools:
+        """
+        Convert the tools and handoffs to the appropriate format.
+
+        Args:
+            tools: The tools.
+            handoffs: The handoffs.
+
+        Returns:
+            The converted tools and includes.
+        """
         converted_tools: list[ToolParam] = []
         includes: list[IncludeLiteral] = []
 
@@ -329,8 +397,16 @@ def convert_tools(
     @classmethod
     def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]:
-        """Returns converted tool and includes"""
+        """
+        Convert a tool to the appropriate format.
+
+        Args:
+            tool: The tool.
+
+        Returns:
+            The converted tool and includes.
+        """
         if isinstance(tool, FunctionTool):
             converted_tool: ToolParam = {
                 "name": tool.name,
@@ -377,6 +452,15 @@ def _convert_tool(cls, tool: Tool) -> tuple[ToolParam, IncludeLiteral | None]:
 
     @classmethod
     def _convert_handoff_tool(cls, handoff: Handoff) -> ToolParam:
+        """
+        Convert a handoff to the appropriate format.
+
+        Args:
+            handoff: The handoff.
+
+        Returns:
+            The converted handoff tool.
+        """
         return {
             "name": handoff.tool_name,
             "parameters": handoff.input_json_schema,
diff --git a/src/agents/result.py b/src/agents/result.py
index 6e806b72..2117d921 100644
--- a/src/agents/result.py
+++ b/src/agents/result.py
@@ -181,7 +181,8 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:
         if self._stored_exception:
             raise self._stored_exception
 
-    def _check_errors(self):
+    def _check_errors(self) -> None:
+        """Check for errors in the agent run."""
         if self.current_turn > self.max_turns:
             self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
 
@@ -207,7 +208,8 @@ def _check_errors(self):
         if exc and isinstance(exc, Exception):
             self._stored_exception = exc
 
-    def _cleanup_tasks(self):
+    def _cleanup_tasks(self) -> None:
+        """Clean up the asyncio tasks."""
         if self._run_impl_task and not self._run_impl_task.done():
             self._run_impl_task.cancel()
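
Example (sketch, not part of the patch): `stream_events` re-raises whatever
`_check_errors` stored, so failures surface from the iteration loop.

    from agents import Agent, Runner

    async def main() -> None:
        agent = Agent(name="Assistant", instructions="Reply concisely.")
        result = Runner.run_streamed(agent, input="Hello")
        async for event in result.stream_events():
            print(type(event).__name__)
        # MaxTurnsExceeded or a guardrail tripwire would be raised from the loop.
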
diff --git a/src/agents/strict_schema.py b/src/agents/strict_schema.py
index 910ad85f..d8394215 100644
--- a/src/agents/strict_schema.py
+++ b/src/agents/strict_schema.py
@@ -1,13 +1,13 @@
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Dict, List
 
 from openai import NOT_GIVEN
 from typing_extensions import TypeGuard
 
 from .exceptions import UserError
 
-_EMPTY_SCHEMA = {
+_EMPTY_SCHEMA: Dict[str, Any] = {
     "additionalProperties": False,
     "type": "object",
     "properties": {},
@@ -16,8 +16,8 @@
 
 def ensure_strict_json_schema(
-    schema: dict[str, Any],
-) -> dict[str, Any]:
+    schema: Dict[str, Any],
+) -> Dict[str, Any]:
     """Mutates the given JSON schema to ensure it conforms to the `strict` standard
     that the OpenAI API expects.
     """
@@ -26,13 +26,26 @@ def ensure_strict_json_schema(
     return _ensure_strict_json_schema(schema, path=(), root=schema)
 
 
-# Adapted from https://github.com/openai/openai-python/blob/main/src/openai/lib/_pydantic.py
 def _ensure_strict_json_schema(
-    json_schema: object,
+    json_schema: Dict[str, Any],
     *,
     path: tuple[str, ...],
-    root: dict[str, object],
-) -> dict[str, Any]:
+    root: Dict[str, Any],
+) -> Dict[str, Any]:
+    """Ensures that the given JSON schema conforms to the `strict` standard.
+
+    Args:
+        json_schema: The JSON schema to ensure.
+        path: The path to the current schema.
+        root: The root schema.
+
+    Returns:
+        The schema, updated in place to conform to the strict standard.
+
+    Raises:
+        TypeError: If the given JSON schema is not a dictionary.
+        UserError: If additionalProperties is set to True for object types.
+    """
     if not is_dict(json_schema):
         raise TypeError(f"Expected {json_schema} to be a dictionary; path={path}")
 
@@ -63,8 +76,6 @@ def _ensure_strict_json_schema(
                 "to not use a strict schema."
             )
 
-    # object types
-    # { 'type': 'object', 'properties': { 'a': {...} } }
     properties = json_schema.get("properties")
     if is_dict(properties):
         json_schema["required"] = list(properties.keys())
@@ -73,13 +84,10 @@ def _ensure_strict_json_schema(
             for key, prop_schema in properties.items()
         }
 
-    # arrays
-    # { 'type': 'array', 'items': {...} }
     items = json_schema.get("items")
     if is_dict(items):
         json_schema["items"] = _ensure_strict_json_schema(items, path=(*path, "items"), root=root)
 
-    # unions
     any_of = json_schema.get("anyOf")
     if is_list(any_of):
         json_schema["anyOf"] = [
@@ -87,7 +95,6 @@ def _ensure_strict_json_schema(
             for i, variant in enumerate(any_of)
         ]
 
-    # intersections
     all_of = json_schema.get("allOf")
     if is_list(all_of):
         if len(all_of) == 1:
@@ -101,17 +108,9 @@ def _ensure_strict_json_schema(
             for i, entry in enumerate(all_of)
         ]
 
-    # strip `None` defaults as there's no meaningful distinction here
-    # the schema will still be `nullable` and the model will default
-    # to using `None` anyway
     if json_schema.get("default", NOT_GIVEN) is None:
         json_schema.pop("default")
 
-    # we can't use `$ref`s if there are also other properties defined, e.g.
-    # `{"$ref": "...", "description": "my description"}`
-    #
-    # so we unravel the ref
-    # `{"type": "string", "description": "my description"}`
     ref = json_schema.get("$ref")
     if ref and has_more_than_n_keys(json_schema, 1):
         assert isinstance(ref, str), f"Received non-string $ref - {ref}"
@@ -122,17 +121,26 @@ def _ensure_strict_json_schema(
             f"Expected `$ref: {ref}` to resolved to a dictionary but got {resolved}"
         )
 
-        # properties from the json schema take priority over the ones on the `$ref`
         json_schema.update({**resolved, **json_schema})
         json_schema.pop("$ref")
 
-        # Since the schema expanded from `$ref` might not have `additionalProperties: false` applied
-        # we call `_ensure_strict_json_schema` again to fix the inlined schema and ensure it's valid
         return _ensure_strict_json_schema(json_schema, path=path, root=root)
 
     return json_schema
 
 
-def resolve_ref(*, root: dict[str, object], ref: str) -> object:
+def resolve_ref(*, root: Dict[str, Any], ref: str) -> Any:
+    """Resolves a JSON schema reference.
+
+    Args:
+        root: The root schema.
+        ref: The reference to resolve.
+
+    Returns:
+        The resolved reference.
+
+    Raises:
+        ValueError: If the reference format is unexpected.
+    """
     if not ref.startswith("#/"):
         raise ValueError(f"Unexpected $ref format {ref!r}; Does not start with #/")
 
@@ -148,17 +156,40 @@ def resolve_ref(*, root: Dict[str, Any], ref: str) -> Any:
     return resolved
 
 
-def is_dict(obj: object) -> TypeGuard[dict[str, object]]:
-    # just pretend that we know there are only `str` keys
-    # as that check is not worth the performance cost
+def is_dict(obj: Any) -> TypeGuard[Dict[str, Any]]:
+    """Checks if the given object is a dictionary.
+
+    Args:
+        obj: The object to check.
+
+    Returns:
+        True if the object is a dictionary, False otherwise.
+    """
     return isinstance(obj, dict)
 
 
-def is_list(obj: object) -> TypeGuard[list[object]]:
+def is_list(obj: Any) -> TypeGuard[List[Any]]:
+    """Checks if the given object is a list.
+
+    Args:
+        obj: The object to check.
+
+    Returns:
+        True if the object is a list, False otherwise.
+    """
     return isinstance(obj, list)
 
 
-def has_more_than_n_keys(obj: dict[str, object], n: int) -> bool:
+def has_more_than_n_keys(obj: Dict[str, Any], n: int) -> bool:
+    """Checks if the given dictionary has more than n keys.
+
+    Args:
+        obj: The dictionary to check.
+        n: The number of keys to compare against.
+
+    Returns:
+        True if the dictionary has more than n keys, False otherwise.
+    """
     i = 0
     for _ in obj.keys():
         i += 1
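
Example (sketch, not part of the patch): what the strict-schema pass above
does to a permissive object schema.

    from agents.strict_schema import ensure_strict_json_schema

    schema = {
        "type": "object",
        "properties": {"city": {"type": "string", "default": None}},
    }
    strict = ensure_strict_json_schema(schema)
    print(strict["required"])                         # ['city']
    print("default" in strict["properties"]["city"])  # False: None default stripped
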
""" - return GLOBAL_TRACE_PROVIDER.create_span( - span_data=FunctionSpanData(name=name, input=input, output=output), - span_id=span_id, - parent=parent, - disabled=disabled, - ) + try: + return GLOBAL_TRACE_PROVIDER.create_span( + span_data=FunctionSpanData(name=name, input=input, output=output), + span_id=span_id, + parent=parent, + disabled=disabled, + ) + except Exception as e: + logger.error(f"Error creating function span: {e}") + raise def generation_span( @@ -179,14 +199,18 @@ def generation_span( Returns: The newly created generation span. """ - return GLOBAL_TRACE_PROVIDER.create_span( - span_data=GenerationSpanData( - input=input, output=output, model=model, model_config=model_config, usage=usage - ), - span_id=span_id, - parent=parent, - disabled=disabled, - ) + try: + return GLOBAL_TRACE_PROVIDER.create_span( + span_data=GenerationSpanData( + input=input, output=output, model=model, model_config=model_config, usage=usage + ), + span_id=span_id, + parent=parent, + disabled=disabled, + ) + except Exception as e: + logger.error(f"Error creating generation span: {e}") + raise def response_span( @@ -207,12 +231,16 @@ def response_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( - span_data=ResponseSpanData(response=response), - span_id=span_id, - parent=parent, - disabled=disabled, - ) + try: + return GLOBAL_TRACE_PROVIDER.create_span( + span_data=ResponseSpanData(response=response), + span_id=span_id, + parent=parent, + disabled=disabled, + ) + except Exception as e: + logger.error(f"Error creating response span: {e}") + raise def handoff_span( @@ -238,12 +266,16 @@ def handoff_span( Returns: The newly created handoff span. """ - return GLOBAL_TRACE_PROVIDER.create_span( - span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent), - span_id=span_id, - parent=parent, - disabled=disabled, - ) + try: + return GLOBAL_TRACE_PROVIDER.create_span( + span_data=HandoffSpanData(from_agent=from_agent, to_agent=to_agent), + span_id=span_id, + parent=parent, + disabled=disabled, + ) + except Exception as e: + logger.error(f"Error creating handoff span: {e}") + raise def custom_span( @@ -270,12 +302,16 @@ def custom_span( Returns: The newly created custom span. """ - return GLOBAL_TRACE_PROVIDER.create_span( - span_data=CustomSpanData(name=name, data=data or {}), - span_id=span_id, - parent=parent, - disabled=disabled, - ) + try: + return GLOBAL_TRACE_PROVIDER.create_span( + span_data=CustomSpanData(name=name, data=data or {}), + span_id=span_id, + parent=parent, + disabled=disabled, + ) + except Exception as e: + logger.error(f"Error creating custom span: {e}") + raise def guardrail_span( @@ -298,9 +334,13 @@ def guardrail_span( trace/span as the parent. disabled: If True, we will return a Span but the Span will not be recorded. """ - return GLOBAL_TRACE_PROVIDER.create_span( - span_data=GuardrailSpanData(name=name, triggered=triggered), - span_id=span_id, - parent=parent, - disabled=disabled, - ) + try: + return GLOBAL_TRACE_PROVIDER.create_span( + span_data=GuardrailSpanData(name=name, triggered=triggered), + span_id=span_id, + parent=parent, + disabled=disabled, + ) + except Exception as e: + logger.error(f"Error creating guardrail span: {e}") + raise