From 25a633139dddef3e0bccd44d0553cc3d1a277072 Mon Sep 17 00:00:00 2001
From: Rohan Mehta
Date: Wed, 12 Mar 2025 16:06:17 -0700
Subject: [PATCH] Add examples and documentation for using custom model
 providers

---
 docs/models.md                                | 25 +++----
 examples/model_providers/README.md            | 19 +++++
 .../model_providers/custom_example_agent.py   | 51 +++++++++++++
 .../model_providers/custom_example_global.py  | 55 ++++++++++++++
 .../custom_example_provider.py                | 73 +++++++++++++++++++
 src/agents/__init__.py                        | 14 +++-
 src/agents/_config.py                         |  9 ++-
 src/agents/models/openai_provider.py          | 35 ++++++---
 8 files changed, 247 insertions(+), 34 deletions(-)
 create mode 100644 examples/model_providers/README.md
 create mode 100644 examples/model_providers/custom_example_agent.py
 create mode 100644 examples/model_providers/custom_example_global.py
 create mode 100644 examples/model_providers/custom_example_provider.py

diff --git a/docs/models.md b/docs/models.md
index 7ad515bc..a4801fee 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -53,21 +53,14 @@ async def main():
 
 ## Using other LLM providers
 
-Many providers also support the OpenAI API format, which means you can pass a `base_url` to the existing OpenAI model implementations and use them easily. `ModelSettings` is used to configure tuning parameters (e.g., temperature, top_p) for the model you select.
+You can use other LLM providers in 3 ways (examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/)):
 
-```python
-external_client = AsyncOpenAI(
-    api_key="EXTERNAL_API_KEY",
-    base_url="https://api.external.com/v1/",
-)
+1. [`set_default_openai_client`][agents.set_default_openai_client] is useful when you want to globally use an instance of `AsyncOpenAI` as the LLM client. Use it when the LLM provider has an OpenAI-compatible API endpoint and you can set the `base_url` and `api_key`. See a configurable example in [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py).
+2. [`ModelProvider`][agents.models.interface.ModelProvider] applies at the `Runner.run` level. It lets you say "use a custom model provider for all agents in this run". See a configurable example in [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py).
+3. [`Agent.model`][agents.agent.Agent.model] lets you specify the model on a specific Agent instance. This enables you to mix and match different providers for different agents. See a configurable example in [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py).
 
-spanish_agent = Agent(
-    name="Spanish agent",
-    instructions="You only speak Spanish.",
-    model=OpenAIChatCompletionsModel(
-        model="EXTERNAL_MODEL_NAME",
-        openai_client=external_client,
-    ),
-    model_settings=ModelSettings(temperature=0.5),
-)
-```
+If you do not have an API key from `platform.openai.com`, we recommend disabling tracing via `set_tracing_disabled()`, or setting up a [different tracing processor](tracing.md).
+
+!!! note
+
+    In these examples, we use the Chat Completions API/model, because most LLM providers don't yet support the Responses API. If your LLM provider does support it, we recommend using Responses.
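For reference, the per-agent approach (option 3 above) can be sketched roughly as follows, combining the code block removed from docs/models.md with the new examples. This is a minimal sketch: the `EXTERNAL_API_KEY`, base URL, and `EXTERNAL_MODEL_NAME` values are placeholders for your provider's values, and the agent name/instructions simply mirror the new example files.

```python
import asyncio

from openai import AsyncOpenAI

from agents import Agent, OpenAIChatCompletionsModel, Runner, set_tracing_disabled

# Placeholder credentials and model name -- substitute your provider's values.
external_client = AsyncOpenAI(
    api_key="EXTERNAL_API_KEY",
    base_url="https://api.external.com/v1/",
)

# Trace export goes to OpenAI, so disable it if you have no platform.openai.com key.
set_tracing_disabled(disabled=True)

agent = Agent(
    name="Assistant",
    instructions="You only respond in haikus.",
    # Option 3: attach the externally hosted model to this specific agent only.
    model=OpenAIChatCompletionsModel(
        model="EXTERNAL_MODEL_NAME",
        openai_client=external_client,
    ),
)


async def main() -> None:
    result = await Runner.run(agent, "Tell me about recursion in programming.")
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
```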
diff --git a/examples/model_providers/README.md b/examples/model_providers/README.md
new file mode 100644
index 00000000..f9330c24
--- /dev/null
+++ b/examples/model_providers/README.md
@@ -0,0 +1,19 @@
+# Custom LLM providers
+
+The examples in this directory demonstrate how you might use a non-OpenAI LLM provider. To run them, first set a base URL, API key, and model name.
+
+```bash
+export EXAMPLE_BASE_URL="..."
+export EXAMPLE_API_KEY="..."
+export EXAMPLE_MODEL_NAME="..."
+```
+
+Then run the examples, e.g.:
+
+```
+python examples/model_providers/custom_example_provider.py
+
+Loops within themselves,
+Function calls its own being,
+Depth without ending.
+```
diff --git a/examples/model_providers/custom_example_agent.py b/examples/model_providers/custom_example_agent.py
new file mode 100644
index 00000000..d7519a52
--- /dev/null
+++ b/examples/model_providers/custom_example_agent.py
@@ -0,0 +1,51 @@
+import asyncio
+import os
+
+from openai import AsyncOpenAI
+
+from agents import Agent, OpenAIChatCompletionsModel, Runner, set_tracing_disabled
+
+BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
+API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
+MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
+
+if not BASE_URL or not API_KEY or not MODEL_NAME:
+    raise ValueError(
+        "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
+    )
+
+"""This example uses a custom provider for a specific agent. Steps:
+1. Create a custom OpenAI client.
+2. Create a `Model` that uses the custom client.
+3. Set the `model` on the Agent.
+
+Note that in this example, we disable tracing under the assumption that you don't have an API key
+from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
+or call set_tracing_export_api_key() to set a tracing-specific key.
+"""
+client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
+set_tracing_disabled(disabled=True)
+
+# An alternate approach that would also work:
+# PROVIDER = OpenAIProvider(openai_client=client)
+# agent = Agent(..., model="some-custom-model")
+# Runner.run(agent, ..., run_config=RunConfig(model_provider=PROVIDER))
+
+
+async def main():
+    # This agent will use the custom LLM provider
+    agent = Agent(
+        name="Assistant",
+        instructions="You only respond in haikus.",
+        model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),
+    )
+
+    result = await Runner.run(
+        agent,
+        "Tell me about recursion in programming.",
+    )
+    print(result.final_output)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/model_providers/custom_example_global.py b/examples/model_providers/custom_example_global.py
new file mode 100644
index 00000000..d7c293b1
--- /dev/null
+++ b/examples/model_providers/custom_example_global.py
@@ -0,0 +1,55 @@
+import asyncio
+import os
+
+from openai import AsyncOpenAI
+
+from agents import (
+    Agent,
+    Runner,
+    set_default_openai_api,
+    set_default_openai_client,
+    set_tracing_disabled,
+)
+
+BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
+API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
+MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
+
+if not BASE_URL or not API_KEY or not MODEL_NAME:
+    raise ValueError(
+        "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
+    )
+
+
+"""This example uses a custom provider for all requests by default. We do three things:
+1. Create a custom client.
+2. Set it as the default OpenAI client, and don't use it for tracing.
+3. Set the default API as Chat Completions, as most LLM providers don't yet support the Responses API.
+
+Note that in this example, we disable tracing under the assumption that you don't have an API key
+from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
+or call set_tracing_export_api_key() to set a tracing-specific key.
+"""
+
+client = AsyncOpenAI(
+    base_url=BASE_URL,
+    api_key=API_KEY,
+)
+set_default_openai_client(client=client, use_for_tracing=False)
+set_default_openai_api("chat_completions")
+set_tracing_disabled(disabled=True)
+
+
+async def main():
+    agent = Agent(
+        name="Assistant",
+        instructions="You only respond in haikus.",
+        model=MODEL_NAME,
+    )
+
+    result = await Runner.run(agent, "Tell me about recursion in programming.")
+    print(result.final_output)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/model_providers/custom_example_provider.py b/examples/model_providers/custom_example_provider.py
new file mode 100644
index 00000000..6e8af42e
--- /dev/null
+++ b/examples/model_providers/custom_example_provider.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import asyncio
+import os
+
+from openai import AsyncOpenAI
+
+from agents import (
+    Agent,
+    Model,
+    ModelProvider,
+    OpenAIChatCompletionsModel,
+    RunConfig,
+    Runner,
+    set_tracing_disabled,
+)
+
+BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
+API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
+MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""
+
+if not BASE_URL or not API_KEY or not MODEL_NAME:
+    raise ValueError(
+        "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
+    )
+
+
+"""This example uses a custom provider for some calls to Runner.run(), and calls OpenAI directly for
+others. Steps:
+1. Create a custom OpenAI client.
+2. Create a ModelProvider that uses the custom client.
+3. Use the ModelProvider in calls to Runner.run() only when we want to use the custom LLM provider.
+
+Note that in this example, we disable tracing under the assumption that you don't have an API key
+from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
+or call set_tracing_export_api_key() to set a tracing-specific key.
+"""
+client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
+set_tracing_disabled(disabled=True)
+
+
+class CustomModelProvider(ModelProvider):
+    def get_model(self, model_name: str | None) -> Model:
+        return OpenAIChatCompletionsModel(model=model_name or MODEL_NAME, openai_client=client)
+
+
+CUSTOM_MODEL_PROVIDER = CustomModelProvider()
+
+
+async def main():
+    agent = Agent(
+        name="Assistant",
+        instructions="You only respond in haikus.",
+    )
+
+    # This will use the custom model provider
+    result = await Runner.run(
+        agent,
+        "Tell me about recursion in programming.",
+        run_config=RunConfig(model_provider=CUSTOM_MODEL_PROVIDER),
+    )
+    print(result.final_output)
+
+    # If you uncomment this, it will use OpenAI directly, not the custom provider
+    # result = await Runner.run(
+    #     agent,
+    #     "Tell me about recursion in programming.",
+    # )
+    # print(result.final_output)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/agents/__init__.py b/src/agents/__init__.py
index 69c500ab..79940fe0 100644
--- a/src/agents/__init__.py
+++ b/src/agents/__init__.py
@@ -92,13 +92,19 @@ from .usage import Usage
 
 
-def set_default_openai_key(key: str) -> None:
-    """Set the default OpenAI API key to use for LLM requests and tracing. This is only necessary if
-    the OPENAI_API_KEY environment variable is not already set.
+def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
+    """Set the default OpenAI API key to use for LLM requests (and optionally tracing). This is
+    only necessary if the OPENAI_API_KEY environment variable is not already set.
 
     If provided, this key will be used instead of the OPENAI_API_KEY environment variable.
+
+    Args:
+        key: The OpenAI key to use.
+        use_for_tracing: Whether to also use this key to send traces to OpenAI. Defaults to True.
+            If False, you'll either need to set the OPENAI_API_KEY environment variable or call
+            set_tracing_export_api_key() with the API key you want to use for tracing.
     """
-    _config.set_default_openai_key(key)
+    _config.set_default_openai_key(key, use_for_tracing)
 
 
 def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:
diff --git a/src/agents/_config.py b/src/agents/_config.py
index 55ded64d..304cfb83 100644
--- a/src/agents/_config.py
+++ b/src/agents/_config.py
@@ -5,15 +5,18 @@ from .tracing import set_tracing_export_api_key
 
 
-def set_default_openai_key(key: str) -> None:
-    set_tracing_export_api_key(key)
+def set_default_openai_key(key: str, use_for_tracing: bool) -> None:
     _openai_shared.set_default_openai_key(key)
 
+    if use_for_tracing:
+        set_tracing_export_api_key(key)
+
 
 def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None:
+    _openai_shared.set_default_openai_client(client)
+
+    if use_for_tracing:
         set_tracing_export_api_key(client.api_key)
-    _openai_shared.set_default_openai_client(client)
 
 
 def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
diff --git a/src/agents/models/openai_provider.py b/src/agents/models/openai_provider.py
index 51946638..e6a859fb 100644
--- a/src/agents/models/openai_provider.py
+++ b/src/agents/models/openai_provider.py
@@ -38,28 +38,41 @@ def __init__(
             assert api_key is None and base_url is None, (
                 "Don't provide api_key or base_url if you provide openai_client"
             )
-            self._client = openai_client
+            self._client: AsyncOpenAI | None = openai_client
         else:
-            self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
-                api_key=api_key or _openai_shared.get_default_openai_key(),
-                base_url=base_url,
-                organization=organization,
-                project=project,
-                http_client=shared_http_client(),
-            )
+            self._client = None
+            self._stored_api_key = api_key
+            self._stored_base_url = base_url
+            self._stored_organization = organization
+            self._stored_project = project
 
-        self._is_openai_model = self._client.base_url.host.startswith("api.openai.com")
         if use_responses is not None:
             self._use_responses = use_responses
         else:
             self._use_responses = _openai_shared.get_use_responses_by_default()
 
+    # We lazy-load the client in case you never actually use OpenAIProvider(). Otherwise
+    # AsyncOpenAI() raises an error if you don't have an API key set.
+    def _get_client(self) -> AsyncOpenAI:
+        if self._client is None:
+            self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
+                api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
+                base_url=self._stored_base_url,
+                organization=self._stored_organization,
+                project=self._stored_project,
+                http_client=shared_http_client(),
+            )
+
+        return self._client
+
     def get_model(self, model_name: str | None) -> Model:
         if model_name is None:
             model_name = DEFAULT_MODEL
 
+        client = self._get_client()
+
         return (
-            OpenAIResponsesModel(model=model_name, openai_client=self._client)
+            OpenAIResponsesModel(model=model_name, openai_client=client)
             if self._use_responses
-            else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
+            else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
         )
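To illustrate the `use_for_tracing` split this patch introduces, here is a minimal sketch of decoupling the key used for LLM requests from the key used for trace export. It is a sketch under stated assumptions: both key strings are placeholders, and the import paths follow the modules touched above (`agents.set_default_openai_key` and `agents.tracing.set_tracing_export_api_key`).

```python
from agents import set_default_openai_key
from agents.tracing import set_tracing_export_api_key

# Placeholder key strings -- substitute real values.
LLM_API_KEY = "sk-llm-..."          # key used for LLM requests
TRACING_API_KEY = "sk-tracing-..."  # key from platform.openai.com, used only for trace export

# With use_for_tracing=False, this key is used for LLM requests but is no
# longer forwarded to the tracing exporter (forwarding was the old behavior).
set_default_openai_key(LLM_API_KEY, use_for_tracing=False)

# Give the tracing exporter its own key explicitly.
set_tracing_export_api_key(TRACING_API_KEY)
```

Combined with the lazy `_get_client()` above, nothing in `OpenAIProvider` constructs an `AsyncOpenAI` client (or demands an API key) until a model is actually requested via `get_model()`.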