Add examples and documentation for using custom model providers #110

Merged Mar 13, 2025 · 1 commit · Changes from all commits
25 changes: 9 additions & 16 deletions docs/models.md
@@ -53,21 +53,14 @@ async def main():

## Using other LLM providers

-Many providers also support the OpenAI API format, which means you can pass a `base_url` to the existing OpenAI model implementations and use them easily. `ModelSettings` is used to configure tuning parameters (e.g., temperature, top_p) for the model you select.
+You can use other LLM providers in 3 ways (examples [here](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/)):

-```python
-external_client = AsyncOpenAI(
-    api_key="EXTERNAL_API_KEY",
-    base_url="https://api.external.com/v1/",
-)
+1. [`set_default_openai_client`][agents.set_default_openai_client] is useful in cases where you want to globally use an instance of `AsyncOpenAI` as the LLM client. This is for cases where the LLM provider has an OpenAI compatible API endpoint, and you can set the `base_url` and `api_key`. See a configurable example in [examples/model_providers/custom_example_global.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_global.py).
+2. [`ModelProvider`][agents.models.interface.ModelProvider] is at the `Runner.run` level. This lets you say "use a custom model provider for all agents in this run". See a configurable example in [examples/model_providers/custom_example_provider.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_provider.py).
+3. [`Agent.model`][agents.agent.Agent.model] lets you specify the model on a specific Agent instance. This enables you to mix and match different providers for different agents. See a configurable example in [examples/model_providers/custom_example_agent.py](https://github.com/openai/openai-agents-python/tree/main/examples/model_providers/custom_example_agent.py).

-spanish_agent = Agent(
-    name="Spanish agent",
-    instructions="You only speak Spanish.",
-    model=OpenAIChatCompletionsModel(
-        model="EXTERNAL_MODEL_NAME",
-        openai_client=external_client,
-    ),
-    model_settings=ModelSettings(temperature=0.5),
-)
-```
+In cases where you do not have an API key from `platform.openai.com`, we recommend disabling tracing via `set_tracing_disabled()`, or setting up a [different tracing processor](tracing.md).

+!!! note
+
+    In these examples, we use the Chat Completions API/model, because most LLM providers don't yet support the Responses API. If your LLM provider does support it, we recommend using Responses.
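For quick orientation before the full examples below, here is a minimal, hedged sketch of approach 3. The base URL, API key, and model name are placeholders invented for illustration, not real endpoints or credentials:

```python
from openai import AsyncOpenAI

from agents import Agent, OpenAIChatCompletionsModel, set_tracing_disabled

# Placeholder credentials for an OpenAI-compatible provider (assumptions, not real values).
client = AsyncOpenAI(base_url="https://api.example.com/v1/", api_key="PROVIDER_API_KEY")
set_tracing_disabled(disabled=True)  # no platform.openai.com key in this scenario

# Only this agent uses the custom provider; other agents could keep the default.
agent = Agent(
    name="Assistant",
    instructions="You only respond in haikus.",
    model=OpenAIChatCompletionsModel(model="provider-model-name", openai_client=client),
)
```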
19 changes: 19 additions & 0 deletions examples/model_providers/README.md
@@ -0,0 +1,19 @@
# Custom LLM providers

The examples in this directory demonstrate how you might use a non-OpenAI LLM provider. To run them, first set a base URL, an API key, and a model name.

```bash
export EXAMPLE_BASE_URL="..."
export EXAMPLE_API_KEY="..."
export EXAMPLE_MODEL_NAME="..."
```

Then run the examples, e.g.:

```bash
python examples/model_providers/custom_example_provider.py

# Sample output (the agent responds in haikus):
Loops within themselves,
Function calls its own being,
Depth without ending.
```
51 changes: 51 additions & 0 deletions examples/model_providers/custom_example_agent.py
@@ -0,0 +1,51 @@
import asyncio
import os

from openai import AsyncOpenAI

from agents import Agent, OpenAIChatCompletionsModel, Runner, set_tracing_disabled

BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""

if not BASE_URL or not API_KEY or not MODEL_NAME:
    raise ValueError(
        "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
    )

"""This example uses a custom provider for a specific agent. Steps:
1. Create a custom OpenAI client.
2. Create a `Model` that uses the custom client.
3. Set the `model` on the Agent.

Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing-specific key.
"""
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
set_tracing_disabled(disabled=True)

# An alternate approach that would also work:
# PROVIDER = OpenAIProvider(openai_client=client)
# agent = Agent(..., model="some-custom-model")
# Runner.run(agent, ..., run_config=RunConfig(model_provider=PROVIDER))


async def main():
    # This agent will use the custom LLM provider
    agent = Agent(
        name="Assistant",
        instructions="You only respond in haikus.",
        model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),
    )

    result = await Runner.run(
        agent,
        "Tell me about recursion in programming.",
    )
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
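Item 3 in docs/models.md mentions mixing and matching providers across agents; here is a hedged sketch of that idea, reusing `client` and `MODEL_NAME` from the example above (the "gpt-4o" default-provider model name is an assumption, not part of this example):

```python
from agents import Agent, OpenAIChatCompletionsModel

# One agent on the custom, OpenAI-compatible provider...
custom_agent = Agent(
    name="Custom provider agent",
    instructions="You only respond in haikus.",
    model=OpenAIChatCompletionsModel(model=MODEL_NAME, openai_client=client),
)

# ...and another on the default OpenAI provider, via a plain model name string.
openai_agent = Agent(
    name="OpenAI agent",
    instructions="You only respond in haikus.",
    model="gpt-4o",  # assumes OPENAI_API_KEY is configured for the default provider
)
```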
55 changes: 55 additions & 0 deletions examples/model_providers/custom_example_global.py
@@ -0,0 +1,55 @@
import asyncio
import os

from openai import AsyncOpenAI

from agents import (
    Agent,
    Runner,
    set_default_openai_api,
    set_default_openai_client,
    set_tracing_disabled,
)

BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""

if not BASE_URL or not API_KEY or not MODEL_NAME:
    raise ValueError(
        "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
    )


"""This example uses a custom provider for all requests by default. We do three things:
1. Create a custom client.
2. Set it as the default OpenAI client, and don't use it for tracing.
3. Set the default API as Chat Completions, as most LLM providers don't yet support Responses API.

Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing-specific key.
"""

client = AsyncOpenAI(
    base_url=BASE_URL,
    api_key=API_KEY,
)
set_default_openai_client(client=client, use_for_tracing=False)
set_default_openai_api("chat_completions")
set_tracing_disabled(disabled=True)


async def main():
    agent = Agent(
        name="Assistant",
        instructions="You only respond in haikus.",
        model=MODEL_NAME,
    )

    result = await Runner.run(agent, "Tell me about recursion in programming.")
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
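The docstring above notes that `set_tracing_export_api_key()` lets you keep tracing on a separate platform.openai.com key; a minimal sketch, assuming a hypothetical `OPENAI_TRACING_KEY` environment variable:

```python
import os

from agents.tracing import set_tracing_export_api_key  # import path per src/agents/_config.py

# Export traces with a dedicated platform.openai.com key, while LLM requests
# continue to go to the custom provider configured above.
tracing_key = os.getenv("OPENAI_TRACING_KEY")  # hypothetical variable name
if tracing_key:
    set_tracing_export_api_key(tracing_key)
```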
73 changes: 73 additions & 0 deletions examples/model_providers/custom_example_provider.py
@@ -0,0 +1,73 @@
from __future__ import annotations

import asyncio
import os

from openai import AsyncOpenAI

from agents import (
    Agent,
    Model,
    ModelProvider,
    OpenAIChatCompletionsModel,
    RunConfig,
    Runner,
    set_tracing_disabled,
)

BASE_URL = os.getenv("EXAMPLE_BASE_URL") or ""
API_KEY = os.getenv("EXAMPLE_API_KEY") or ""
MODEL_NAME = os.getenv("EXAMPLE_MODEL_NAME") or ""

if not BASE_URL or not API_KEY or not MODEL_NAME:
    raise ValueError(
        "Please set EXAMPLE_BASE_URL, EXAMPLE_API_KEY, EXAMPLE_MODEL_NAME via env var or code."
    )


"""This example uses a custom provider for some calls to Runner.run(), and direct calls to OpenAI for
others. Steps:
1. Create a custom OpenAI client.
2. Create a ModelProvider that uses the custom client.
3. Use the ModelProvider in calls to Runner.run(), only when we want to use the custom LLM provider.

Note that in this example, we disable tracing under the assumption that you don't have an API key
from platform.openai.com. If you do have one, you can either set the `OPENAI_API_KEY` env var
or call set_tracing_export_api_key() to set a tracing-specific key.
"""
client = AsyncOpenAI(base_url=BASE_URL, api_key=API_KEY)
set_tracing_disabled(disabled=True)


class CustomModelProvider(ModelProvider):
    def get_model(self, model_name: str | None) -> Model:
        return OpenAIChatCompletionsModel(model=model_name or MODEL_NAME, openai_client=client)


CUSTOM_MODEL_PROVIDER = CustomModelProvider()


async def main():
    agent = Agent(
        name="Assistant",
        instructions="You only respond in haikus.",
    )

    # This will use the custom model provider
    result = await Runner.run(
        agent,
        "Tell me about recursion in programming.",
        run_config=RunConfig(model_provider=CUSTOM_MODEL_PROVIDER),
    )
    print(result.final_output)

    # If you uncomment this, it will use OpenAI directly, not the custom provider
    # result = await Runner.run(
    #     agent,
    #     "Tell me about recursion in programming.",
    # )
    # print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
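A hedged variation on `CustomModelProvider`: routing by model name, so a single run can mix the custom client with the stock `OpenAIProvider`. The "local-" prefix convention is invented for illustration, and constructing `OpenAIProvider()` without credentials relies on the lazy-loading change in src/agents/models/openai_provider.py below:

```python
from agents.models.openai_provider import OpenAIProvider

DEFAULT_PROVIDER = OpenAIProvider()  # safe to construct: the client is lazy-loaded


class RoutingModelProvider(ModelProvider):
    """Send model names with a (hypothetical) "local-" prefix to the custom
    client; fall back to the default OpenAI provider for everything else."""

    def get_model(self, model_name: str | None) -> Model:
        name = model_name or MODEL_NAME
        if name.startswith("local-"):
            return OpenAIChatCompletionsModel(model=name, openai_client=client)
        return DEFAULT_PROVIDER.get_model(name)
```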
14 changes: 10 additions & 4 deletions src/agents/__init__.py
@@ -92,13 +92,19 @@
from .usage import Usage


-def set_default_openai_key(key: str) -> None:
-    """Set the default OpenAI API key to use for LLM requests and tracing. This is only necessary if
-    the OPENAI_API_KEY environment variable is not already set.
+def set_default_openai_key(key: str, use_for_tracing: bool = True) -> None:
+    """Set the default OpenAI API key to use for LLM requests (and optionally tracing). This is
+    only necessary if the OPENAI_API_KEY environment variable is not already set.

     If provided, this key will be used instead of the OPENAI_API_KEY environment variable.

+    Args:
+        key: The OpenAI key to use.
+        use_for_tracing: Whether to also use this key to send traces to OpenAI. Defaults to True.
+            If False, you'll either need to set the OPENAI_API_KEY environment variable or call
+            set_tracing_export_api_key() with the API key you want to use for tracing.
     """
-    _config.set_default_openai_key(key)
+    _config.set_default_openai_key(key, use_for_tracing)


def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool = True) -> None:
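A short usage sketch for the new `use_for_tracing` flag (the key literal is a placeholder):

```python
from agents import set_default_openai_key

# Use this key for LLM requests only. Tracing then needs its own key via
# set_tracing_export_api_key(), or can be turned off with set_tracing_disabled().
set_default_openai_key("sk-PLACEHOLDER", use_for_tracing=False)
```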
9 changes: 6 additions & 3 deletions src/agents/_config.py
@@ -5,15 +5,18 @@
from .tracing import set_tracing_export_api_key


-def set_default_openai_key(key: str) -> None:
-    set_tracing_export_api_key(key)
+def set_default_openai_key(key: str, use_for_tracing: bool) -> None:
     _openai_shared.set_default_openai_key(key)
+
+    if use_for_tracing:
+        set_tracing_export_api_key(key)


 def set_default_openai_client(client: AsyncOpenAI, use_for_tracing: bool) -> None:
-    _openai_shared.set_default_openai_client(client)
     if use_for_tracing:
         set_tracing_export_api_key(client.api_key)
+    _openai_shared.set_default_openai_client(client)


def set_default_openai_api(api: Literal["chat_completions", "responses"]) -> None:
35 changes: 24 additions & 11 deletions src/agents/models/openai_provider.py
@@ -38,28 +38,41 @@ def __init__(
             assert api_key is None and base_url is None, (
                 "Don't provide api_key or base_url if you provide openai_client"
             )
-            self._client = openai_client
+            self._client: AsyncOpenAI | None = openai_client
         else:
-            self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
-                api_key=api_key or _openai_shared.get_default_openai_key(),
-                base_url=base_url,
-                organization=organization,
-                project=project,
-                http_client=shared_http_client(),
-            )
+            self._client = None
+            self._stored_api_key = api_key
+            self._stored_base_url = base_url
+            self._stored_organization = organization
+            self._stored_project = project

-        self._is_openai_model = self._client.base_url.host.startswith("api.openai.com")
         if use_responses is not None:
             self._use_responses = use_responses
         else:
             self._use_responses = _openai_shared.get_use_responses_by_default()

+    # We lazy load the client in case you never actually use OpenAIProvider(). Otherwise
+    # AsyncOpenAI() raises an error if you don't have an API key set.
+    def _get_client(self) -> AsyncOpenAI:
+        if self._client is None:
+            self._client = _openai_shared.get_default_openai_client() or AsyncOpenAI(
+                api_key=self._stored_api_key or _openai_shared.get_default_openai_key(),
+                base_url=self._stored_base_url,
+                organization=self._stored_organization,
+                project=self._stored_project,
+                http_client=shared_http_client(),
+            )
+
+        return self._client
+
     def get_model(self, model_name: str | None) -> Model:
         if model_name is None:
             model_name = DEFAULT_MODEL

+        client = self._get_client()
+
         return (
-            OpenAIResponsesModel(model=model_name, openai_client=self._client)
+            OpenAIResponsesModel(model=model_name, openai_client=client)
             if self._use_responses
-            else OpenAIChatCompletionsModel(model=model_name, openai_client=self._client)
+            else OpenAIChatCompletionsModel(model=model_name, openai_client=client)
         )
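A small sketch of what the lazy initialization above enables, assuming no `OPENAI_API_KEY` is set in the environment:

```python
from agents.models.openai_provider import OpenAIProvider

# Constructing the provider no longer requires credentials...
provider = OpenAIProvider(use_responses=False)

# ...because AsyncOpenAI() is created only on first use. This next call would
# build the client and thus needs a key or default client to be configured:
# model = provider.get_model("gpt-4o-mini")
```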