diff --git a/.github/ISSUE_TEMPLATE/model_provider.md b/.github/ISSUE_TEMPLATE/model_provider.md new file mode 100644 index 00000000..b56cb24e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/model_provider.md @@ -0,0 +1,26 @@ +--- +name: Custom model providers +about: Questions or bugs about using non-OpenAI models +title: '' +labels: bug +assignees: '' + +--- + +### Please read this first + +- **Have you read the custom model provider docs, including the 'Common issues' section?** [Model provider docs](https://openai.github.io/openai-agents-python/models/#using-other-llm-providers) +- **Have you searched for related issues?** Others may have faced similar issues. + +### Describe the question +A clear and concise description of what the question or bug is. + +### Debug information +- Agents SDK version: (e.g. `v0.0.3`) +- Python version (e.g. Python 3.10) + +### Repro steps +Ideally provide a minimal python script that can be run to reproduce the issue. + +### Expected behavior +A clear and concise description of what you expected to happen. diff --git a/docs/models.md b/docs/models.md index 7ad515bc..ab4cefb8 100644 --- a/docs/models.md +++ b/docs/models.md @@ -53,21 +53,41 @@ async def main(): ## Using other LLM providers -Many providers also support the OpenAI API format, which means you can pass a `base_url` to the existing OpenAI model implementations and use them easily. `ModelSettings` is used to configure tuning parameters (e.g., temperature, top_p) for the model you select. +You can use other LLM providers in 3 ways (examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/)): -```python -external_client = AsyncOpenAI( - api_key="EXTERNAL_API_KEY", - base_url="https://api.external.com/v1/", -) +1. [`set_default_openai_client`][agents.set_default_openai_client] is useful in cases where you want to globally use an instance of `AsyncOpenAI` as the LLM client. This is for cases where the LLM provider has an OpenAI compatible API endpoint, and you can set the `base_url` and `api_key`. See a configurable example in [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py). +2. [`ModelProvider`][agents.models.interface.ModelProvider] is at the `Runner.run` level. This lets you say "use a custom model provider for all agents in this run". See a configurable example in [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py). +3. [`Agent.model`][agents.agent.Agent.model] lets you specify the model on a specific Agent instance. This enables you to mix and match different providers for different agents. See a configurable example in [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py). + +In cases where you do not have an API key from `platform.openai.com`, we recommend disabling tracing via `set_tracing_disabled()`, or setting up a [different tracing processor](tracing.md). + +!!! note + + In these examples, we use the Chat Completions API/model, because most LLM providers don't yet support the Responses API. If your LLM provider does support it, we recommend using Responses. + +## Common issues with using other LLM providers + +### Tracing client error 401 + +If you get errors related to tracing, this is because traces are uploaded to OpenAI servers, and you don't have an OpenAI API key. You have three options to resolve this: + +1. Disable tracing entirely: [`set_tracing_disabled(True)`][agents.set_tracing_disabled]. +2. Set an OpenAI key for tracing: [`set_tracing_export_api_key(...)`][agents.set_tracing_export_api_key]. This API key will only be used for uploading traces, and must be from [platform.openai.com](https://platform.openai.com/). +3. Use a non-OpenAI trace processor. See the [tracing docs](tracing.md#custom-tracing-processors). + +### Responses API support + +The SDK uses the Responses API by default, but most other LLM providers don't yet support it. You may see 404s or similar issues as a result. To resolve, you have two options: + +1. Call [`set_default_openai_api("chat_completions")`][agents.set_default_openai_api]. This works if you are setting `OPENAI_API_KEY` and `OPENAI_BASE_URL` via environment vars. +2. Use [`OpenAIChatCompletionsModel`][agents.models.openai_chatcompletions.OpenAIChatCompletionsModel]. There are examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/). + +### Structured outputs support + +Some model providers don't have support for [structured outputs](https://platform.openai.com/docs/guides/structured-outputs). This sometimes results in an error that looks something like this: -spanish_agent = Agent( - name="Spanish agent", - instructions="You only speak Spanish.", - model=OpenAIChatCompletionsModel( - model="EXTERNAL_MODEL_NAME", - openai_client=external_client, - ), - model_settings=ModelSettings(temperature=0.5), -) ``` +BadRequestError: Error code: 400 - {'error': {'message': "'response_format.type' : value is not one of the allowed values ['text','json_object']", 'type': 'invalid_request_error'}} +``` + +This is a shortcoming of some model providers - they support JSON outputs, but don't allow you to specify the `json_schema` to use for the output. We are working on a fix for this, but we suggest relying on providers that do have support for JSON schema output, because otherwise your app will often break because of malformed JSON. diff --git a/examples/model_providers/README.md b/examples/model_providers/README.md new file mode 100644 index 00000000..f9330c24 --- /dev/null +++ b/examples/model_providers/README.md @@ -0,0 +1,19 @@ +# Custom LLM providers + +The examples in this directory demonstrate how you might use a non-OpenAI LLM provider. To run them, first set a base URL, API key and model. + +```bash +export EXAMPLE_BASE_URL="..." +export EXAMPLE_API_KEY="..." +export EXAMPLE_MODEL_NAME"..." +``` + +Then run the examples, e.g.: + +``` +python examples/model_providers/custom_example_provider.py + +Loops within themselves, +Function calls its own being, +Depth without ending. +``` diff --git a/examples/model_providers/custom_example_agent.py b/examples/model_providers/custom_example_agent.py new file mode 100644 index 00000000..f10865c4 --- /dev/null +++ b/examples/model_providers/custom_example_agent.py @@ -0,0 +1,55 @@ +import asyncio +import os + +from openai import AsyncOpenAI + +from agents import Agent, OpenAIChatCompletionsModel, Runner, function_tool, set_tracing_disabled + +BASE_URL = os.getenv("EXAMPLE_BASE_URL") or "" +API_KEY = os.getenv("EXAMPLE_API_KEY") or "" +MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or "" + +if not BASE_URL or not API_KEY or not MODEL_NAME: + raise ValueError( + "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code." + ) + +"""This example uses a custom provider for a specific agent. Steps: +1. Create a custom OpenAI client. +2. Create a `Model` that uses the custom client. +3. Set the `model` on the Agent. + +Note that in this example, we disable tracing under the assumption that you don't have an API key +from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var +or call set_tracing_export_api_key() to set a tracing specific key. +""" +client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY) +set_tracing_disabled(disabled=True) + +# An alternate approach that would also work: +# PROVIDER = OpenAIProvider(openai_client=client) +# agent = Agent(..., model="some-custom-model") +# Runner.run(agent, ..., run_config=RunConfig(model_provider=PROVIDER)) + + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +async def main(): + # This agent will use the custom LLM provider + agent = Agent( + name="Assistant", + instructions="You only respond in haikus.", + model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client), + tools=[get_weather], + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/model_providers/custom_example_global.py b/examples/model_providers/custom_example_global.py new file mode 100644 index 00000000..ae9756d3 --- /dev/null +++ b/examples/model_providers/custom_example_global.py @@ -0,0 +1,63 @@ +import asyncio +import os + +from openai import AsyncOpenAI + +from agents import ( + Agent, + Runner, + function_tool, + set_default_openai_api, + set_default_openai_client, + set_tracing_disabled, +) + +BASE_URL = os.getenv("EXAMPLE_BASE_URL") or "" +API_KEY = os.getenv("EXAMPLE_API_KEY") or "" +MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or "" + +if not BASE_URL or not API_KEY or not MODEL_NAME: + raise ValueError( + "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code." + ) + + +"""This example uses a custom provider for all requests by default. We do three things: +1. Create a custom client. +2. Set it as the default OpenAI client, and don't use it for tracing. +3. Set the default API as Chat Completions, as most LLM providers don't yet support Responses API. + +Note that in this example, we disable tracing under the assumption that you don't have an API key +from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var +or call set_tracing_export_api_key() to set a tracing specific key. +""" + +client = AsyncOpenAI( + base_url=BASE_URL, + api_key=API_KEY, +) +set_default_openai_client(client=client, use_for_tracing=False) +set_default_openai_api("chat_completions") +set_tracing_disabled(disabled=True) + + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You only respond in haikus.", + model=MODEL_NAME, + tools=[get_weather], + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/model_providers/custom_example_provider.py b/examples/model_providers/custom_example_provider.py new file mode 100644 index 00000000..4e590198 --- /dev/null +++ b/examples/model_providers/custom_example_provider.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import asyncio +import os + +from openai import AsyncOpenAI + +from agents import ( + Agent, + Model, + ModelProvider, + OpenAIChatCompletionsModel, + RunConfig, + Runner, + function_tool, + set_tracing_disabled, +) + +BASE_URL = os.getenv("EXAMPLE_BASE_URL") or "" +API_KEY = os.getenv("EXAMPLE_API_KEY") or "" +MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or "" + +if not BASE_URL or not API_KEY or not MODEL_NAME: + raise ValueError( + "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code." + ) + + +"""This example uses a custom provider for some calls to Runner.run(), and direct calls to OpenAI for +others. Steps: +1. Create a custom OpenAI client. +2. Create a ModelProvider that uses the custom client. +3. Use the ModelProvider in calls to Runner.run(), only when we want to use the custom LLM provider. + +Note that in this example, we disable tracing under the assumption that you don't have an API key +from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var +or call set_tracing_export_api_key() to set a tracing specific key. +""" +client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY) +set_tracing_disabled(disabled=True) + + +class CustomModelProvider(ModelProvider): + def get_model(self, model_name: str | None) -> Model: + return OpenAIChatCompletionsModel(model=model_name or MODEL_NAME, openai_client=client) + + +CUSTOM_MODEL_PROVIDER = CustomModelProvider() + + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +async def main(): + agent = Agent(name="Assistant", instructions="You only respond in haikus.", tools=[get_weather]) + + # This will use the custom model provider + result = await Runner.run( + agent, + "What's the weather in Tokyo?", + run_config=RunConfig(model_provider=CUSTOM_MODEL_PROVIDER), + ) + print(result.final_output) + + # If you uncomment this, it will use OpenAI directly, not the custom provider + # result = await Runner.run( + # agent, + # "What's the weather in Tokyo?", + # ) + # print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/model_providers/gemini_example.py b/examples/model_providers/gemini_example.py new file mode 100644 index 00000000..38a262a9 --- /dev/null +++ b/examples/model_providers/gemini_example.py @@ -0,0 +1,59 @@ +import os +import asyncio +from dotenv import load_dotenv + +from agents import ( + Agent, + Runner, + GeminiProvider, + RunConfig, + function_tool +) + +# Load environment variables from .env file +load_dotenv() + +# Get the API key from environment variables +gemini_api_key = os.environ.get("GEMINI_API_KEY") + +if not gemini_api_key: + raise ValueError("GEMINI_API_KEY environment variable is not set") + +# Create a Gemini provider with your API key +gemini_provider = GeminiProvider(api_key=gemini_api_key) + +# Define a simple function tool +@function_tool +def get_weather(city: str) -> str: + """Get the current weather for a city.""" + # In a real application, this would call a weather API + return f"The weather in {city} is sunny and 75°F." + +# Define an agent using Gemini +agent = Agent( + name="Gemini Assistant", + instructions="You are a helpful assistant powered by Google Gemini.", + tools=[get_weather], +) + +async def main(): + # Create a run configuration that uses the Gemini provider + config = RunConfig( + model_provider=gemini_provider, + # Specify the model to use (default is "gemini-2.0-flash") + model="gemini-2.0-flash", + ) + + # Run the agent with the Gemini provider + result = await Runner.run( + agent, + "What's the weather like in Tokyo?", + run_config=config, + ) + + # Print the final output + print("\nFinal output:") + print(result.final_output) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 262ce17c..8184a670 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-agents" -version = "0.0.3" +version = "0.0.4" description = "OpenAI Agents SDK" readme = "README.md" requires-python = ">=3.9" diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 69c500ab..65a01661 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -44,6 +44,8 @@ from .models.openai_chatcompletions import OpenAIChatCompletionsModel from .models.openai_provider import OpenAIProvider from .models.openai_responses import OpenAIResponsesModel +from .models.gemini_chatcompletions import GeminiChatCompletionsModel +from .models.gemini_provider import GeminiProvider from .result import RunResult, RunResultStreaming from .run import RunConfig, Runner from .run_context import RunContextWrapper, TContext @@ -92,13 +94,24 @@ from .usage import Usage -def set_default_openai_key(key: str) -> None: - """Set the default OpenAI API key to use for LLM requests and tracing. This is only necessary if - the OPENAI_API_KEY environment variable is not already set. +def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None: + """Set the default OpenAI API key to use for LLM requests (and optionally tracing(). This is + only necessary if the OPENAI_API_KEY environment variable is not already set. If provided, this key will be used instead of the OPENAI_API_KEY environment variable. + + Args: + key: The OpenAI key to use. + use_for_tracing: Whether to also use this key to send traces to OpenAI. Defaults to True + If False, you'll either need to set the OPENAI_API_KEY environment variable or call + set_tracing_export_api_key() with the API key you want to use for tracing. """ - _config.set_default_openai_key(key) + _config.set_default_openai_key(key, use_for_tracing) + try: + _config.set_default_openai_key(key) + except Exception as e: + logging.error(f"Error setting default OpenAI key: {e}") + raise def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None: @@ -111,22 +124,37 @@ def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) you'll either need to set the OPENAI_API_KEY environment variable or call set_tracing_export_api_key() with the API key you want to use for tracing. """ - _config.set_default_openai_client(client, use_for_tracing) + try: + _config.set_default_openai_client(client, use_for_tracing) + except Exception as e: + logging.error(f"Error setting default OpenAI client: {e}") + raise def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None: """Set the default API to use for OpenAI LLM requests. By default, we will use the responses API but you can set this to use the chat completions API instead. """ - _config.set_default_openai_api(api) + try: + _config.set_default_openai_api(api) + except Exception as e: + logging.error(f"Error setting default OpenAI API: {e}") + raise -def enable_verbose_stdout_logging(): +def enable_verbose_stdout_logging() -> None: """Enables verbose logging to stdout. This is useful for debugging.""" - for name in ["openai.agents", "openai.agents.tracing"]: - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler(sys.stdout)) + logger = logging.getLogger("openai.agents") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler(sys.stdout)) + try: + for name in ["openai.agents", "openai.agents.tracing"]: + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler(sys.stdout)) + except Exception as e: + logging.error(f"Error enabling verbose stdout logging: {e}") + raise __all__ = [ @@ -139,6 +167,8 @@ def enable_verbose_stdout_logging(): "OpenAIChatCompletionsModel", "OpenAIProvider", "OpenAIResponsesModel", + "GeminiChatCompletionsModel", + "GeminiProvider", "AgentOutputSchema", "Computer", "AsyncComputer", diff --git a/src/agents/_config.py b/src/agents/_config.py index 55ded64d..304cfb83 100644 --- a/src/agents/_config.py +++ b/src/agents/_config.py @@ -5,15 +5,18 @@ from .tracing import set_tracing_export_api_key -def set_default_openai_key(key: str) -> None: - set_tracing_export_api_key(key) +def set_default_openai_key(key: str, use_for_tracing: bool) -> None: _openai_shared.set_default_openai_key(key) + if use_for_tracing: + set_tracing_export_api_key(key) + def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None: + _openai_shared.set_default_openai_client(client) + if use_for_tracing: set_tracing_export_api_key(client.api_key) - _openai_shared.set_default_openai_client(client) def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None: diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 2c849506..8bd3d01f 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -2,7 +2,7 @@ import asyncio from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Union, List from openai.types.responses import ( ResponseComputerToolCall, @@ -65,6 +65,7 @@ class QueueCompleteSentinel: + """A sentinel value to indicate that the queue is complete.""" pass @@ -97,8 +98,7 @@ class ProcessedResponse: computer_actions: list[ToolRunComputerAction] def has_tools_to_run(self) -> bool: - # Handoffs, functions and computer actions need local processing - # Hosted tools have already run, so there's nothing to do. + """Check if there are any tools to run.""" return any( [ self.handoffs, @@ -151,6 +151,7 @@ def generated_items(self) -> list[RunItem]: def get_model_tracing_impl( tracing_disabled: bool, trace_include_sensitive_data: bool ) -> ModelTracing: + """Get the model tracing implementation based on the tracing configuration.""" if tracing_disabled: return ModelTracing.DISABLED elif trace_include_sensitive_data: @@ -176,6 +177,7 @@ async def execute_tools_and_side_effects( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, ) -> SingleStepResult: + """Execute tools and side effects for the current step.""" # Make a copy of the generated items pre_step_items = list(pre_step_items) @@ -271,6 +273,7 @@ def process_model_response( output_schema: AgentOutputSchema | None, handoffs: list[Handoff], ) -> ProcessedResponse: + """Process the model response and extract relevant information.""" items: list[RunItem] = [] run_handoffs = [] @@ -356,6 +359,7 @@ async def execute_function_tool_calls( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> list[RunItem]: + """Execute function tool calls.""" async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall ) -> str: @@ -422,6 +426,7 @@ async def execute_computer_actions( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> list[RunItem]: + """Execute computer actions.""" results: list[RunItem] = [] # Need to run these serially, because each action can affect the computer state for action in actions: @@ -451,6 +456,7 @@ async def execute_handoffs( context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, ) -> SingleStepResult: + """Execute handoffs.""" # If there is more than one handoff, add tool responses that reject those handoffs if len(run_handoffs) > 1: output_message = "Multiple handoffs detected, ignoring this one." @@ -470,9 +476,20 @@ async def execute_handoffs( actual_handoff = run_handoffs[0] with handoff_span(from_agent=agent.name) as span_handoff: handoff = actual_handoff.handoff - new_agent: Agent[Any] = await handoff.on_invoke_handoff( - context_wrapper, actual_handoff.tool_call.arguments - ) + try: + new_agent: Agent[Any] = await handoff.on_invoke_handoff( + context_wrapper, actual_handoff.tool_call.arguments + ) + except Exception as e: + _utils.attach_error_to_span( + span_handoff, + SpanError( + message="Error invoking handoff", + data={"error": str(e)}, + ) + ) + raise + span_handoff.span_data.to_agent = new_agent.name # Append a tool output item for the handoff @@ -568,6 +585,7 @@ async def execute_final_output( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], ) -> SingleStepResult: + """Execute final output.""" # Run the on_end hooks await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output) @@ -587,6 +605,7 @@ async def run_final_output_hooks( context_wrapper: RunContextWrapper[TContext], final_output: Any, ): + """Run the final output hooks.""" await asyncio.gather( hooks.on_agent_end(context_wrapper, agent, final_output), agent.hooks.on_end(context_wrapper, agent, final_output) @@ -602,8 +621,19 @@ async def run_single_input_guardrail( input: str | list[TResponseInputItem], context: RunContextWrapper[TContext], ) -> InputGuardrailResult: + """Run a single input guardrail.""" with guardrail_span(guardrail.get_name()) as span_guardrail: - result = await guardrail.run(agent, input, context) + try: + result = await guardrail.run(agent, input, context) + except Exception as e: + _utils.attach_error_to_span( + span_guardrail, + SpanError( + message="Error running input guardrail", + data={"error": str(e)}, + ) + ) + raise span_guardrail.span_data.triggered = result.output.tripwire_triggered return result @@ -615,8 +645,19 @@ async def run_single_output_guardrail( agent_output: Any, context: RunContextWrapper[TContext], ) -> OutputGuardrailResult: + """Run a single output guardrail.""" with guardrail_span(guardrail.get_name()) as span_guardrail: - result = await guardrail.run(agent=agent, agent_output=agent_output, context=context) + try: + result = await guardrail.run(agent=agent, agent_output=agent_output, context=context) + except Exception as e: + _utils.attach_error_to_span( + span_guardrail, + SpanError( + message="Error running output guardrail", + data={"error": str(e)}, + ) + ) + raise span_guardrail.span_data.triggered = result.output.tripwire_triggered return result @@ -626,6 +667,7 @@ def stream_step_result_to_queue( step_result: SingleStepResult, queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel], ): + """Stream the step result to the queue.""" for item in step_result.new_step_items: if isinstance(item, MessageOutputItem): event = RunItemStreamEvent(item=item, name="message_output_created") @@ -695,6 +737,7 @@ async def execute( context_wrapper: RunContextWrapper[TContext], config: RunConfig, ) -> RunItem: + """Execute a computer action.""" output_func = ( cls._get_screenshot_async(action.computer_tool.computer, action.tool_call) if isinstance(action.computer_tool.computer, AsyncComputer) @@ -741,25 +784,29 @@ async def _get_screenshot_sync( computer: Computer, tool_call: ResponseComputerToolCall, ) -> str: + """Get a screenshot synchronously.""" action = tool_call.action - if isinstance(action, ActionClick): - computer.click(action.x, action.y, action.button) - elif isinstance(action, ActionDoubleClick): - computer.double_click(action.x, action.y) - elif isinstance(action, ActionDrag): - computer.drag([(p.x, p.y) for p in action.path]) - elif isinstance(action, ActionKeypress): - computer.keypress(action.keys) - elif isinstance(action, ActionMove): - computer.move(action.x, action.y) - elif isinstance(action, ActionScreenshot): - computer.screenshot() - elif isinstance(action, ActionScroll): - computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y) - elif isinstance(action, ActionType): - computer.type(action.text) - elif isinstance(action, ActionWait): - computer.wait() + try: + if isinstance(action, ActionClick): + computer.click(action.x, action.y, action.button) + elif isinstance(action, ActionDoubleClick): + computer.double_click(action.x, action.y) + elif isinstance(action, ActionDrag): + computer.drag([(p.x, p.y) for p in action.path]) + elif isinstance(action, ActionKeypress): + computer.keypress(action.keys) + elif isinstance(action, ActionMove): + computer.move(action.x, action.y) + elif isinstance(action, ActionScreenshot): + computer.screenshot() + elif isinstance(action, ActionScroll): + computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y) + elif isinstance(action, ActionType): + computer.type(action.text) + elif isinstance(action, ActionWait): + computer.wait() + except Exception as e: + raise ModelBehaviorError(f"Error executing computer action: {e}") return computer.screenshot() @@ -769,24 +816,28 @@ async def _get_screenshot_async( computer: AsyncComputer, tool_call: ResponseComputerToolCall, ) -> str: + """Get a screenshot asynchronously.""" action = tool_call.action - if isinstance(action, ActionClick): - await computer.click(action.x, action.y, action.button) - elif isinstance(action, ActionDoubleClick): - await computer.double_click(action.x, action.y) - elif isinstance(action, ActionDrag): - await computer.drag([(p.x, p.y) for p in action.path]) - elif isinstance(action, ActionKeypress): - await computer.keypress(action.keys) - elif isinstance(action, ActionMove): - await computer.move(action.x, action.y) - elif isinstance(action, ActionScreenshot): - await computer.screenshot() - elif isinstance(action, ActionScroll): - await computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y) - elif isinstance(action, ActionType): - await computer.type(action.text) - elif isinstance(action, ActionWait): - await computer.wait() + try: + if isinstance(action, ActionClick): + await computer.click(action.x, action.y, action.button) + elif isinstance(action, ActionDoubleClick): + await computer.double_click(action.x, action.y) + elif isinstance(action, ActionDrag): + await computer.drag([(p.x, p.y) for p in action.path]) + elif isinstance(action, ActionKeypress): + await computer.keypress(action.keys) + elif isinstance(action, ActionMove): + await computer.move(action.x, action.y) + elif isinstance(action, ActionScreenshot): + await computer.screenshot() + elif isinstance(action, ActionScroll): + await computer.scroll(action.x, action.y, action.scroll_x, action.scroll_y) + elif isinstance(action, ActionType): + await computer.type(action.text) + elif isinstance(action, ActionWait): + await computer.wait() + except Exception as e: + raise ModelBehaviorError(f"Error executing computer action: {e}") return await computer.screenshot() diff --git a/src/agents/agent.py b/src/agents/agent.py index 61c0a896..33fe2839 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -132,28 +132,36 @@ def as_tool( async def run_agent(context: RunContextWrapper, input: str) -> str: from .run import Runner - output = await Runner.run( - starting_agent=self, - input=input, - context=context.context, - ) - if custom_output_extractor: - return await custom_output_extractor(output) - - return ItemHelpers.text_message_outputs(output.new_items) + try: + output = await Runner.run( + starting_agent=self, + input=input, + context=context.context, + ) + if custom_output_extractor: + return await custom_output_extractor(output) + + return ItemHelpers.text_message_outputs(output.new_items) + except Exception as e: + logger.error(f"Error running agent as tool: {e}") + raise return run_agent async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None: """Get the system prompt for the agent.""" - if isinstance(self.instructions, str): - return self.instructions - elif callable(self.instructions): - if inspect.iscoroutinefunction(self.instructions): - return await cast(Awaitable[str], self.instructions(run_context, self)) - else: - return cast(str, self.instructions(run_context, self)) - elif self.instructions is not None: - logger.error(f"Instructions must be a string or a function, got {self.instructions}") - - return None + try: + if isinstance(self.instructions, str): + return self.instructions + elif callable(self.instructions): + if inspect.iscoroutinefunction(self.instructions): + return await cast(Awaitable[str], self.instructions(run_context, self)) + else: + return cast(str, self.instructions(run_context, self)) + elif self.instructions is not None: + logger.error(f"Instructions must be a string or a function, got {self.instructions}") + + return None + except Exception as e: + logger.error(f"Error getting system prompt: {e}") + raise diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 0c28800f..d4bde2f8 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -87,31 +87,40 @@ def validate_json(self, json_str: str, partial: bool = False) -> Any: """Validate a JSON string against the output type. Returns the validated object, or raises a `ModelBehaviorError` if the JSON is invalid. """ - validated = _utils.validate_json(json_str, self._type_adapter, partial) - if self._is_wrapped: - if not isinstance(validated, dict): - _utils.attach_error_to_current_span( - SpanError( - message="Invalid JSON", - data={"details": f"Expected a dict, got {type(validated)}"}, + try: + validated = _utils.validate_json(json_str, self._type_adapter, partial) + if self._is_wrapped: + if not isinstance(validated, dict): + _utils.attach_error_to_current_span( + SpanError( + message="Invalid JSON", + data={"details": f"Expected a dict, got {type(validated)}"}, + ) + ) + raise ModelBehaviorError( + f"Expected a dict, got {type(validated)} for JSON: {json_str}" ) - ) - raise ModelBehaviorError( - f"Expected a dict, got {type(validated)} for JSON: {json_str}" - ) - if _WRAPPER_DICT_KEY not in validated: - _utils.attach_error_to_current_span( - SpanError( - message="Invalid JSON", - data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"}, + if _WRAPPER_DICT_KEY not in validated: + _utils.attach_error_to_current_span( + SpanError( + message="Invalid JSON", + data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"}, + ) ) + raise ModelBehaviorError( + f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}" + ) + return validated[_WRAPPER_DICT_KEY] + return validated + except Exception as e: + _utils.attach_error_to_current_span( + SpanError( + message="Error validating JSON", + data={"error": str(e)}, ) - raise ModelBehaviorError( - f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}" - ) - return validated[_WRAPPER_DICT_KEY] - return validated + ) + raise def output_type_name(self) -> str: """The name of the output type.""" @@ -119,6 +128,7 @@ def output_type_name(self) -> str: def _is_subclass_of_base_model_or_dict(t: Any) -> bool: + """Check if a type is a subclass of BaseModel or dict.""" if not isinstance(t, type): return False @@ -131,6 +141,7 @@ def _is_subclass_of_base_model_or_dict(t: Any) -> bool: def _type_to_str(t: type[Any]) -> str: + """Convert a type to its string representation.""" origin = get_origin(t) args = get_args(t) diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py index 78898f01..93f56d93 100644 --- a/src/agents/exceptions.py +++ b/src/agents/exceptions.py @@ -6,6 +6,7 @@ class AgentsException(Exception): """Base class for all exceptions in the Agents SDK.""" + pass class MaxTurnsExceeded(AgentsException): @@ -13,7 +14,7 @@ class MaxTurnsExceeded(AgentsException): message: str - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message @@ -24,7 +25,7 @@ class ModelBehaviorError(AgentsException): message: str - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message @@ -33,7 +34,7 @@ class UserError(AgentsException): message: str - def __init__(self, message: str): + def __init__(self, message: str) -> None: self.message = message @@ -43,7 +44,7 @@ class InputGuardrailTripwireTriggered(AgentsException): guardrail_result: "InputGuardrailResult" """The result data of the guardrail that was triggered.""" - def __init__(self, guardrail_result: "InputGuardrailResult"): + def __init__(self, guardrail_result: "InputGuardrailResult") -> None: self.guardrail_result = guardrail_result super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" @@ -56,7 +57,7 @@ class OutputGuardrailTripwireTriggered(AgentsException): guardrail_result: "OutputGuardrailResult" """The result data of the guardrail that was triggered.""" - def __init__(self, guardrail_result: "OutputGuardrailResult"): + def __init__(self, guardrail_result: "OutputGuardrailResult") -> None: self.guardrail_result = guardrail_result super().__init__( f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire" diff --git a/src/agents/function_schema.py b/src/agents/function_schema.py index a4b57672..1c1ebe27 100644 --- a/src/agents/function_schema.py +++ b/src/agents/function_schema.py @@ -87,6 +87,15 @@ class FuncDocumentation: # As of Feb 2025, the automatic style detection in griffe is an Insiders feature. This # code approximates it. def _detect_docstring_style(doc: str) -> DocstringStyle: + """ + Detects the style of a docstring. + + Args: + doc: The docstring to detect the style of. + + Returns: + The detected docstring style. + """ scores: dict[DocstringStyle, int] = {"sphinx": 0, "numpy": 0, "google": 0} # Sphinx style detection: look for :param, :type, :return:, and :rtype: @@ -128,6 +137,9 @@ def _detect_docstring_style(doc: str) -> DocstringStyle: @contextlib.contextmanager def _suppress_griffe_logging(): + """ + Context manager to suppress griffe logging. + """ # Supresses warnings about missing annotations for params logger = logging.getLogger("griffe") previous_level = logger.getEffectiveLevel() @@ -210,131 +222,133 @@ def function_schema( A `FuncSchema` object containing the function's name, description, parameter descriptions, and other metadata. """ - - # 1. Grab docstring info - if use_docstring_info: - doc_info = generate_func_documentation(func, docstring_style) - param_descs = doc_info.param_descriptions or {} - else: - doc_info = None - param_descs = {} - - func_name = name_override or doc_info.name if doc_info else func.__name__ - - # 2. Inspect function signature and get type hints - sig = inspect.signature(func) - type_hints = get_type_hints(func) - params = list(sig.parameters.items()) - takes_context = False - filtered_params = [] - - if params: - first_name, first_param = params[0] - # Prefer the evaluated type hint if available - ann = type_hints.get(first_name, first_param.annotation) - if ann != inspect._empty: - origin = get_origin(ann) or ann - if origin is RunContextWrapper: - takes_context = True # Mark that the function takes context - else: - filtered_params.append((first_name, first_param)) + try: + # 1. Grab docstring info + if use_docstring_info: + doc_info = generate_func_documentation(func, docstring_style) + param_descs = doc_info.param_descriptions or {} else: - filtered_params.append((first_name, first_param)) - - # For parameters other than the first, raise error if any use RunContextWrapper. - for name, param in params[1:]: - ann = type_hints.get(name, param.annotation) - if ann != inspect._empty: - origin = get_origin(ann) or ann - if origin is RunContextWrapper: - raise UserError( - f"RunContextWrapper param found at non-first position in function" - f" {func.__name__}" - ) - filtered_params.append((name, param)) - - # We will collect field definitions for create_model as a dict: - # field_name -> (type_annotation, default_value_or_Field(...)) - fields: dict[str, Any] = {} - - for name, param in filtered_params: - ann = type_hints.get(name, param.annotation) - default = param.default - - # If there's no type hint, assume `Any` - if ann == inspect._empty: - ann = Any - - # If a docstring param description exists, use it - field_description = param_descs.get(name, None) - - # Handle different parameter kinds - if param.kind == param.VAR_POSITIONAL: - # e.g. *args: extend positional args - if get_origin(ann) is tuple: - # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int] - args_of_tuple = get_args(ann) - if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis: - ann = list[args_of_tuple[0]] # type: ignore + doc_info = None + param_descs = {} + + func_name = name_override or doc_info.name if doc_info else func.__name__ + + # 2. Inspect function signature and get type hints + sig = inspect.signature(func) + type_hints = get_type_hints(func) + params = list(sig.parameters.items()) + takes_context = False + filtered_params = [] + + if params: + first_name, first_param = params[0] + # Prefer the evaluated type hint if available + ann = type_hints.get(first_name, first_param.annotation) + if ann != inspect._empty: + origin = get_origin(ann) or ann + if origin is RunContextWrapper: + takes_context = True # Mark that the function takes context else: - ann = list[Any] + filtered_params.append((first_name, first_param)) else: - # If user wrote *args: int, treat as List[int] - ann = list[ann] # type: ignore - - # Default factory to empty list - fields[name] = ( - ann, - Field(default_factory=list, description=field_description), # type: ignore - ) - - elif param.kind == param.VAR_KEYWORD: - # **kwargs handling - if get_origin(ann) is dict: - # e.g. def foo(**kwargs: dict[str, int]) - dict_args = get_args(ann) - if len(dict_args) == 2: - ann = dict[dict_args[0], dict_args[1]] # type: ignore - else: - ann = dict[str, Any] - else: - # e.g. def foo(**kwargs: int) -> Dict[str, int] - ann = dict[str, ann] # type: ignore + filtered_params.append((first_name, first_param)) - fields[name] = ( - ann, - Field(default_factory=dict, description=field_description), # type: ignore - ) + # For parameters other than the first, raise error if any use RunContextWrapper. + for name, param in params[1:]: + ann = type_hints.get(name, param.annotation) + if ann != inspect._empty: + origin = get_origin(ann) or ann + if origin is RunContextWrapper: + raise UserError( + f"RunContextWrapper param found at non-first position in function" + f" {func.__name__}" + ) + filtered_params.append((name, param)) + + # We will collect field definitions for create_model as a dict: + # field_name -> (type_annotation, default_value_or_Field(...)) + fields: dict[str, Any] = {} + + for name, param in filtered_params: + ann = type_hints.get(name, param.annotation) + default = param.default + + # If there's no type hint, assume `Any` + if ann == inspect._empty: + ann = Any + + # If a docstring param description exists, use it + field_description = param_descs.get(name, None) + + # Handle different parameter kinds + if param.kind == param.VAR_POSITIONAL: + # e.g. *args: extend positional args + if get_origin(ann) is tuple: + # e.g. def foo(*args: tuple[int, ...]) -> treat as List[int] + args_of_tuple = get_args(ann) + if len(args_of_tuple) == 2 and args_of_tuple[1] is Ellipsis: + ann = list[args_of_tuple[0]] # type: ignore + else: + ann = list[Any] + else: + # If user wrote *args: int, treat as List[int] + ann = list[ann] # type: ignore - else: - # Normal parameter - if default == inspect._empty: - # Required field + # Default factory to empty list fields[name] = ( ann, - Field(..., description=field_description), + Field(default_factory=list, description=field_description), # type: ignore ) - else: - # Parameter with a default value + + elif param.kind == param.VAR_KEYWORD: + # **kwargs handling + if get_origin(ann) is dict: + # e.g. def foo(**kwargs: dict[str, int]) + dict_args = get_args(ann) + if len(dict_args) == 2: + ann = dict[dict_args[0], dict_args[1]] # type: ignore + else: + ann = dict[str, Any] + else: + # e.g. def foo(**kwargs: int) -> Dict[str, int] + ann = dict[str, ann] # type: ignore + fields[name] = ( ann, - Field(default=default, description=field_description), + Field(default_factory=dict, description=field_description), # type: ignore ) - # 3. Dynamically build a Pydantic model - dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields) - - # 4. Build JSON schema from that model - json_schema = dynamic_model.model_json_schema() - if strict_json_schema: - json_schema = ensure_strict_json_schema(json_schema) - - # 5. Return as a FuncSchema dataclass - return FuncSchema( - name=func_name, - description=description_override or doc_info.description if doc_info else None, - params_pydantic_model=dynamic_model, - params_json_schema=json_schema, - signature=sig, - takes_context=takes_context, - ) + else: + # Normal parameter + if default == inspect._empty: + # Required field + fields[name] = ( + ann, + Field(..., description=field_description), + ) + else: + # Parameter with a default value + fields[name] = ( + ann, + Field(default=default, description=field_description), + ) + + # 3. Dynamically build a Pydantic model + dynamic_model = create_model(f"{func_name}_args", __base__=BaseModel, **fields) + + # 4. Build JSON schema from that model + json_schema = dynamic_model.model_json_schema() + if strict_json_schema: + json_schema = ensure_strict_json_schema(json_schema) + + # 5. Return as a FuncSchema dataclass + return FuncSchema( + name=func_name, + description=description_override or doc_info.description if doc_info else None, + params_pydantic_model=dynamic_model, + params_json_schema=json_schema, + signature=sig, + takes_context=takes_context, + ) + except Exception as e: + raise UserError(f"Error generating function schema: {e}") from e diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 5bebcd66..d58063d8 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -97,6 +97,7 @@ class InputGuardrail(Generic[TContext]): """ def get_name(self) -> str: + """Get the name of the guardrail.""" if self.name: return self.name @@ -108,20 +109,24 @@ async def run( input: str | list[TResponseInputItem], context: RunContextWrapper[TContext], ) -> InputGuardrailResult: + """Run the input guardrail.""" if not callable(self.guardrail_function): raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") - output = self.guardrail_function(context, agent, input) - if inspect.isawaitable(output): + try: + output = self.guardrail_function(context, agent, input) + if inspect.isawaitable(output): + return InputGuardrailResult( + guardrail=self, + output=await output, + ) + return InputGuardrailResult( guardrail=self, - output=await output, + output=output, ) - - return InputGuardrailResult( - guardrail=self, - output=output, - ) + except Exception as e: + raise UserError(f"Error running input guardrail: {e}") from e @dataclass @@ -151,6 +156,7 @@ class OutputGuardrail(Generic[TContext]): """ def get_name(self) -> str: + """Get the name of the guardrail.""" if self.name: return self.name @@ -159,24 +165,28 @@ def get_name(self) -> str: async def run( self, context: RunContextWrapper[TContext], agent: Agent[Any], agent_output: Any ) -> OutputGuardrailResult: + """Run the output guardrail.""" if not callable(self.guardrail_function): raise UserError(f"Guardrail function must be callable, got {self.guardrail_function}") - output = self.guardrail_function(context, agent, agent_output) - if inspect.isawaitable(output): + try: + output = self.guardrail_function(context, agent, agent_output) + if inspect.isawaitable(output): + return OutputGuardrailResult( + guardrail=self, + agent=agent, + agent_output=agent_output, + output=await output, + ) + return OutputGuardrailResult( guardrail=self, agent=agent, agent_output=agent_output, - output=await output, + output=output, ) - - return OutputGuardrailResult( - guardrail=self, - agent=agent, - agent_output=agent_output, - output=output, - ) + except Exception as e: + raise UserError(f"Error running output guardrail: {e}") from e TContext_co = TypeVar("TContext_co", bound=Any, covariant=True) diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index ac157401..fd25e912 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -28,6 +28,10 @@ @dataclass(frozen=True) class HandoffInputData: + """ + The input data passed to the next agent during a handoff. + """ + input_history: str | tuple[TResponseInputItem, ...] """ The input history before `Runner.run()` was called. @@ -51,7 +55,8 @@ class HandoffInputData: @dataclass class Handoff(Generic[TContext]): - """A handoff is when an agent delegates a task to another agent. + """ + A handoff is when an agent delegates a task to another agent. For example, in a customer support scenario you might have a "triage agent" that determines which agent should handle the user's request, and sub-agents that specialize in different areas like billing, account management, etc. @@ -99,15 +104,42 @@ class Handoff(Generic[TContext]): """ def get_transfer_message(self, agent: Agent[Any]) -> str: + """ + Get the transfer message for the handoff. + + Args: + agent: The agent that is being handed off to. + + Returns: + The transfer message. + """ base = f"{{'assistant': '{agent.name}'}}" return base @classmethod def default_tool_name(cls, agent: Agent[Any]) -> str: + """ + Get the default tool name for the handoff. + + Args: + agent: The agent that is being handed off to. + + Returns: + The default tool name. + """ return _utils.transform_string_function_style(f"transfer_to_{agent.name}") @classmethod def default_tool_description(cls, agent: Agent[Any]) -> str: + """ + Get the default tool description for the handoff. + + Args: + agent: The agent that is being handed off to. + + Returns: + The default tool description. + """ return ( f"Handoff to the {agent.name} agent to handle the request. " f"{agent.handoff_description or ''}" @@ -190,34 +222,43 @@ def handoff( async def _invoke_handoff( ctx: RunContextWrapper[Any], input_json: str | None = None ) -> Agent[Any]: - if input_type is not None and type_adapter is not None: - if input_json is None: - _utils.attach_error_to_current_span( - SpanError( - message="Handoff function expected non-null input, but got None", - data={"details": "input_json is None"}, + try: + if input_type is not None and type_adapter is not None: + if input_json is None: + _utils.attach_error_to_current_span( + SpanError( + message="Handoff function expected non-null input, but got None", + data={"details": "input_json is None"}, + ) ) - ) - raise ModelBehaviorError("Handoff function expected non-null input, but got None") + raise ModelBehaviorError("Handoff function expected non-null input, but got None") - validated_input = _utils.validate_json( - json_str=input_json, - type_adapter=type_adapter, - partial=False, + validated_input = _utils.validate_json( + json_str=input_json, + type_adapter=type_adapter, + partial=False, + ) + input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff) + if inspect.iscoroutinefunction(input_func): + await input_func(ctx, validated_input) + else: + input_func(ctx, validated_input) + elif on_handoff is not None: + no_input_func = cast(OnHandoffWithoutInput, on_handoff) + if inspect.iscoroutinefunction(no_input_func): + await no_input_func(ctx) + else: + no_input_func(ctx) + + return agent + except Exception as e: + _utils.attach_error_to_current_span( + SpanError( + message="Error invoking handoff", + data={"error": str(e)}, + ) ) - input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff) - if inspect.iscoroutinefunction(input_func): - await input_func(ctx, validated_input) - else: - input_func(ctx, validated_input) - elif on_handoff is not None: - no_input_func = cast(OnHandoffWithoutInput, on_handoff) - if inspect.iscoroutinefunction(no_input_func): - await no_input_func(ctx) - else: - no_input_func(ctx) - - return agent + raise tool_name = tool_name_override or Handoff.default_tool_name(agent) tool_description = tool_description_override or Handoff.default_tool_description(agent) diff --git a/src/agents/models/gemini_chatcompletions.py b/src/agents/models/gemini_chatcompletions.py new file mode 100644 index 00000000..70ed81d1 --- /dev/null +++ b/src/agents/models/gemini_chatcompletions.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from openai import AsyncOpenAI +from openai.types.chat import ChatModel + +from .openai_chatcompletions import OpenAIChatCompletionsModel + + +class GeminiChatCompletionsModel(OpenAIChatCompletionsModel): + """ + Model implementation for Google Gemini using the OpenAI-compatible API endpoint. + + This class extends the OpenAIChatCompletionsModel since Google's OpenAI-compatible + endpoint follows the same interface as OpenAI's Chat Completions API. + """ + + def __init__( + self, + model: str | ChatModel, + openai_client: AsyncOpenAI, + ) -> None: + super().__init__(model=model, openai_client=openai_client) \ No newline at end of file diff --git a/src/agents/models/gemini_provider.py b/src/agents/models/gemini_provider.py new file mode 100644 index 00000000..d5cf03d2 --- /dev/null +++ b/src/agents/models/gemini_provider.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import httpx +from openai import AsyncOpenAI, DefaultAsyncHttpxClient + +from .interface import Model, ModelProvider +from .gemini_chatcompletions import GeminiChatCompletionsModel + +DEFAULT_MODEL: str = "gemini-2.0-flash" + +_http_client: httpx.AsyncClient | None = None + +# If we create a new httpx client for each request, that would mean no sharing of connection pools, +# which would mean worse latency and resource usage. So, we share the client across requests. +def shared_http_client() -> httpx.AsyncClient: + global _http_client + if _http_client is None: + _http_client = DefaultAsyncHttpxClient() + return _http_client + + +class GeminiProvider(ModelProvider): + """ + Model provider for Google Gemini models. + + Uses Google's OpenAI-compatible API endpoint to integrate with Gemini models. + """ + + def __init__( + self, + *, + api_key: str | None = None, + base_url: str | None = "https://generativelanguage.googleapis.com/v1beta/openai/", + openai_client: AsyncOpenAI | None = None, + default_model: str = DEFAULT_MODEL, + ) -> None: + if openai_client is not None: + assert api_key is None and base_url is None, ( + "Don't provide api_key or base_url if you provide openai_client" + ) + self._client: AsyncOpenAI | None = openai_client + else: + self._client = None + self._stored_api_key = api_key + self._stored_base_url = base_url + + self._default_model = default_model + + # We lazy load the client in case you never actually use GeminiProvider() + def _get_client(self) -> AsyncOpenAI: + if self._client is None: + self._client = AsyncOpenAI( + api_key=self._stored_api_key, + base_url=self._stored_base_url, + http_client=shared_http_client(), + ) + + return self._client + + def get_model(self, model_name: str | None) -> Model: + if model_name is None: + model_name = self._default_model + + client = self._get_client() + + return GeminiChatCompletionsModel(model=model_name, openai_client=client) \ No newline at end of file diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 3543225d..778ea21d 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -115,17 +115,21 @@ async def get_response( | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: - response = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - span_generation, - tracing, - stream=False, - ) + try: + response = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + span_generation, + tracing, + stream=False, + ) + except Exception as e: + logger.error(f"Error fetching response: {e}") + raise if _debug.DONT_LOG_MODEL_DATA: logger.debug("Received model response") @@ -178,17 +182,21 @@ async def stream_response( | {"base_url": str(self._client.base_url)}, disabled=tracing.is_disabled(), ) as span_generation: - response, stream = await self._fetch_response( - system_instructions, - input, - model_settings, - tools, - output_schema, - handoffs, - span_generation, - tracing, - stream=True, - ) + try: + response, stream = await self._fetch_response( + system_instructions, + input, + model_settings, + tools, + output_schema, + handoffs, + span_generation, + tracing, + stream=True, + ) + except Exception as e: + logger.error(f"Error fetching response: {e}") + raise usage: CompletionUsage | None = None state = _StreamingState() @@ -513,22 +521,26 @@ async def _fetch_response( f"Response format: {response_format}\n" ) - ret = await self._get_client().chat.completions.create( - model=self.model, - messages=converted_messages, - tools=converted_tools or NOT_GIVEN, - temperature=self._non_null_or_not_given(model_settings.temperature), - top_p=self._non_null_or_not_given(model_settings.top_p), - frequency_penalty=self._non_null_or_not_given(model_settings.frequency_penalty), - presence_penalty=self._non_null_or_not_given(model_settings.presence_penalty), - max_tokens=self._non_null_or_not_given(model_settings.max_tokens), - tool_choice=tool_choice, - response_format=response_format, - parallel_tool_calls=parallel_tool_calls, - stream=stream, - stream_options={"include_usage": True} if stream else NOT_GIVEN, - extra_headers=_HEADERS, - ) + try: + ret = await self._get_client().chat.completions.create( + model=self.model, + messages=converted_messages, + tools=converted_tools or NOT_GIVEN, + temperature=self._non_null_or_not_given(model_settings.temperature), + top_p=self._non_null_or_not_given(model_settings.top_p), + frequency_penalty=self._non_null_or_not_given(model_settings.frequency_penalty), + presence_penalty=self._non_null_or_not_given(model_settings.presence_penalty), + max_tokens=self._non_null_or_not_given(model_settings.max_tokens), + tool_choice=tool_choice, + response_format=response_format, + parallel_tool_calls=parallel_tool_calls, + stream=stream, + stream_options={"include_usage": True} if stream else NOT_GIVEN, + extra_headers=_HEADERS, + ) + except Exception as e: + logger.error(f"Error creating chat completion: {e}") + raise if isinstance(ret, ChatCompletion): return ret diff --git a/src/agents/models/openai_provider.py b/src/agents/models/openai_provider.py index 51946638..e6a859fb 100644 --- a/src/agents/models/openai_provider.py +++ b/src/agents/models/openai_provider.py @@ -38,28 +38,41 @@ def __init__( assert api_key is None and base_url is None, ( "Don't provide api_key or base_url if you provide openai_client" ) - self._client = openai_client + self._client: AsyncOpenAI | None = openai_client else: - self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI( - api_key=api_key or _openai_shared.get_default_openai_key(), - base_url=base_url, - organization=organization, - project=project, - http_client=shared_http_client(), - ) + self._client = None + self._stored_api_key = api_key + self._stored_base_url = base_url + self._stored_organization = organization + self._stored_project = project - self._is_openai_model = self._client.base_url.host.startswith("api.openai.com") if use_responses is not None: self._use_responses = use_responses else: self._use_responses = _openai_shared.get_use_responses_by_default() + # We lazy load the client in case you never actually use OpenAIProvider(). Otherwise + # AsyncOpenAI() raises an error if you don't have an API key set. + def _get_client(self) -> AsyncOpenAI: + if self._client is None: + self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI( + api_key=self._stored_api_key or _openai_shared.get_default_openai_key(), + base_url=self._stored_base_url, + organization=self._stored_organization, + project=self._stored_project, + http_client=shared_http_client(), + ) + + return self._client + def get_model(self, model_name: str | None) -> Model: if model_name is None: model_name = DEFAULT_MODEL + client = self._get_client() + return ( - OpenAIResponsesModel(model=model_name, openai_client=self._client) + OpenAIResponsesModel(model=model_name, openai_client=client) if self._use_responses - else OpenAIChatCompletionsModel(model=model_name, openai_client=self._client) + else OpenAIChatCompletionsModel(model=model_name, openai_client=client) ) diff --git a/src/agents/tracing/create.py b/src/agents/tracing/create.py index 8d7fc493..78a064bc 100644 --- a/src/agents/tracing/create.py +++ b/src/agents/tracing/create.py @@ -3,7 +3,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from .logger import logger +from ..logger import logger from .setup import GLOBAL_TRACE_PROVIDER from .span_data import ( AgentSpanData, diff --git a/src/agents/tracing/processors.py b/src/agents/tracing/processors.py index 308adf2a..1b39deda 100644 --- a/src/agents/tracing/processors.py +++ b/src/agents/tracing/processors.py @@ -9,7 +9,7 @@ import httpx -from .logger import logger +from ..logger import logger from .processor_interface import TracingExporter, TracingProcessor from .spans import Span from .traces import Trace @@ -40,7 +40,7 @@ def __init__( """ Args: api_key: The API key for the "Authorization" header. Defaults to - `os.environ["OPENAI_TRACE_API_KEY"]` if not provided. + `os.environ["OPENAI_API_KEY"]` if not provided. organization: The OpenAI organization to use. Defaults to `os.environ["OPENAI_ORG_ID"]` if not provided. project: The OpenAI project to use. Defaults to diff --git a/src/agents/tracing/scope.py b/src/agents/tracing/scope.py index 9ccd9f87..513ca8c0 100644 --- a/src/agents/tracing/scope.py +++ b/src/agents/tracing/scope.py @@ -2,7 +2,7 @@ import contextvars from typing import TYPE_CHECKING, Any -from .logger import logger +from ..logger import logger if TYPE_CHECKING: from .spans import Span diff --git a/src/agents/tracing/setup.py b/src/agents/tracing/setup.py index bc340c9f..3a7c6ade 100644 --- a/src/agents/tracing/setup.py +++ b/src/agents/tracing/setup.py @@ -4,8 +4,8 @@ import threading from typing import Any +from ..logger import logger from . import util -from .logger import logger from .processor_interface import TracingProcessor from .scope import Scope from .spans import NoOpSpan, Span, SpanImpl, TSpanData diff --git a/src/agents/tracing/spans.py b/src/agents/tracing/spans.py index d682a9a0..ee933e73 100644 --- a/src/agents/tracing/spans.py +++ b/src/agents/tracing/spans.py @@ -6,8 +6,8 @@ from typing_extensions import TypedDict +from ..logger import logger from . import util -from .logger import logger from .processor_interface import TracingProcessor from .scope import Scope from .span_data import SpanData diff --git a/src/agents/tracing/traces.py b/src/agents/tracing/traces.py index bf3b43df..53d06284 100644 --- a/src/agents/tracing/traces.py +++ b/src/agents/tracing/traces.py @@ -4,8 +4,8 @@ import contextvars from typing import Any +from ..logger import logger from . import util -from .logger import logger from .processor_interface import TracingProcessor from .scope import Scope diff --git a/uv.lock b/uv.lock index 9179bd4f..c3af99bd 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.9" [[package]] @@ -783,7 +782,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.0.3" +version = "0.0.4" source = { editable = "." } dependencies = [ { name = "griffe" },