src/agents/realtime/openai_realtime.py (66 additions, 72 deletions)
@@ -348,7 +348,7 @@ async def send_event(self, event: RealtimeModelSendEvent) -> None:
     async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None:
         """Send a raw message to the model."""
         assert self._websocket is not None, "Not connected"
-        payload = event.model_dump_json(exclude_none=True, exclude_unset=True)
+        payload = event.model_dump_json(exclude_unset=True)
         await self._websocket.send(payload)

     async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None:
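
Note on the serialization change above: in pydantic v2, exclude_unset=True keeps a field that was explicitly assigned None (it serializes as null), whereas the previous exclude_none=True stripped it out again. That explicit null is what lets a caller turn a setting off. A minimal sketch, using a made-up FakeSessionConfig model purely for illustration:

    from pydantic import BaseModel


    class FakeSessionConfig(BaseModel):
        # Hypothetical stand-in for the real session event models.
        turn_detection: dict | None = None
        voice: str | None = None


    config = FakeSessionConfig(turn_detection=None)  # explicitly set to None; voice left unset

    print(config.model_dump_json(exclude_unset=True))
    # {"turn_detection":null}  -- the explicit None survives as null

    print(config.model_dump_json(exclude_none=True, exclude_unset=True))
    # {}  -- the old flag combination silently dropped it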
@@ -829,91 +829,63 @@ def _get_session_config(
         self, model_settings: RealtimeSessionModelSettings
     ) -> OpenAISessionCreateRequest:
         """Get the session config."""
-        model_name = (model_settings.get("model_name") or self.model) or "gpt-realtime"

-        voice = model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice"))
-        speed = model_settings.get("speed")
-        modalities = model_settings.get("modalities", DEFAULT_MODEL_SETTINGS.get("modalities"))
+        audio_input_args = {}

         if self._call_id:
-            input_audio_format = model_settings.get("input_audio_format")
-        else:
-            input_audio_format = model_settings.get(
-                "input_audio_format",
-                DEFAULT_MODEL_SETTINGS.get("input_audio_format"),
+            audio_input_args["format"] = to_realtime_audio_format(
+                model_settings.get("input_audio_format")
             )
-        input_audio_transcription = model_settings.get(
-            "input_audio_transcription",
-            DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"),
-        )
-        turn_detection = model_settings.get(
-            "turn_detection",
-            DEFAULT_MODEL_SETTINGS.get("turn_detection"),
-        )
-        if self._call_id:
-            output_audio_format = model_settings.get("output_audio_format")
         else:
-            output_audio_format = model_settings.get(
-                "output_audio_format",
-                DEFAULT_MODEL_SETTINGS.get("output_audio_format"),
+            audio_input_args["format"] = to_realtime_audio_format(
+                model_settings.get(
+                    "input_audio_format", DEFAULT_MODEL_SETTINGS.get("input_audio_format")
+                )
             )
-        input_audio_noise_reduction = model_settings.get(
-            "input_audio_noise_reduction",
-            DEFAULT_MODEL_SETTINGS.get("input_audio_noise_reduction"),
-        )

-        input_audio_config = None
-        if any(
-            value is not None
-            for value in [
-                input_audio_format,
-                input_audio_noise_reduction,
-                input_audio_transcription,
-                turn_detection,
-            ]
-        ):
-            input_audio_config = OpenAIRealtimeAudioInput(
-                format=to_realtime_audio_format(input_audio_format),
-                noise_reduction=cast(Any, input_audio_noise_reduction),
-                transcription=cast(Any, input_audio_transcription),
-                turn_detection=cast(Any, turn_detection),
-            )
+        if "input_audio_noise_reduction" in model_settings:
+            audio_input_args["noise_reduction"] = model_settings.get("input_audio_noise_reduction")  # type: ignore[assignment]

-        output_audio_config = None
-        if any(value is not None for value in [output_audio_format, speed, voice]):
-            output_audio_config = OpenAIRealtimeAudioOutput(
-                format=to_realtime_audio_format(output_audio_format),
-                speed=speed,
-                voice=voice,
+        if "input_audio_transcription" in model_settings:
+            audio_input_args["transcription"] = model_settings.get("input_audio_transcription")  # type: ignore[assignment]
+        else:
+            audio_input_args["transcription"] = DEFAULT_MODEL_SETTINGS.get(  # type: ignore[assignment]
+                "input_audio_transcription"
             )

-        audio_config = None
-        if input_audio_config or output_audio_config:
-            audio_config = OpenAIRealtimeAudioConfig(
-                input=input_audio_config,
-                output=output_audio_config,
-            )
+        if "turn_detection" in model_settings:
+            audio_input_args["turn_detection"] = model_settings.get("turn_detection")  # type: ignore[assignment]
+        else:
+            audio_input_args["turn_detection"] = DEFAULT_MODEL_SETTINGS.get("turn_detection")  # type: ignore[assignment]

-        prompt: ResponsePrompt | None = None
-        if model_settings.get("prompt") is not None:
-            _passed_prompt: Prompt = model_settings["prompt"]
-            variables: dict[str, Any] | None = _passed_prompt.get("variables")
-            prompt = ResponsePrompt(
-                id=_passed_prompt["id"],
-                variables=variables,
-                version=_passed_prompt.get("version"),
+        audio_output_args = {
+            "voice": model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")),
+        }
+
+        if self._call_id:
+            audio_output_args["format"] = to_realtime_audio_format(  # type: ignore[assignment]
+                model_settings.get("output_audio_format")
             )
+        else:
+            audio_output_args["format"] = to_realtime_audio_format(  # type: ignore[assignment]
+                model_settings.get(
+                    "output_audio_format", DEFAULT_MODEL_SETTINGS.get("output_audio_format")
+                )
+            )

+        if "speed" in model_settings:
+            audio_output_args["speed"] = model_settings.get("speed")  # type: ignore[assignment]
+
-        # Construct full session object. `type` will be excluded at serialization time for updates.
-        return OpenAISessionCreateRequest(
-            model=model_name,
+        session_create_request = OpenAISessionCreateRequest(
             type="realtime",
-            instructions=model_settings.get("instructions"),
-            prompt=prompt,
-            output_modalities=modalities,
-            audio=audio_config,
-            max_output_tokens=cast(Any, model_settings.get("max_output_tokens")),
-            tool_choice=cast(Any, model_settings.get("tool_choice")),
+            model=(model_settings.get("model_name") or self.model) or "gpt-realtime",
+            output_modalities=model_settings.get(
+                "modalities", DEFAULT_MODEL_SETTINGS.get("modalities")
+            ),
+            audio=OpenAIRealtimeAudioConfig(
+                input=OpenAIRealtimeAudioInput(**audio_input_args),  # type: ignore[arg-type]
+                output=OpenAIRealtimeAudioOutput(**audio_output_args),  # type: ignore[arg-type]
+            ),
             tools=cast(
                 Any,
                 self._tools_to_session_tools(
@@ -923,6 +895,28 @@ def _get_session_config(
             ),
         )

+        if "instructions" in model_settings:
+            session_create_request.instructions = model_settings.get("instructions")
+
+        if "prompt" in model_settings:
+            _passed_prompt: Prompt = model_settings["prompt"]
+            variables: dict[str, Any] | None = _passed_prompt.get("variables")
+            session_create_request.prompt = ResponsePrompt(
+                id=_passed_prompt["id"],
+                variables=variables,
+                version=_passed_prompt.get("version"),
+            )
+
+        if "max_output_tokens" in model_settings:
+            session_create_request.max_output_tokens = cast(
+                Any, model_settings.get("max_output_tokens")
+            )
+
+        if "tool_choice" in model_settings:
+            session_create_request.tool_choice = cast(Any, model_settings.get("tool_choice"))
+
+        return session_create_request
+
     def _tools_to_session_tools(
         self, tools: list[Tool], handoffs: list[Handoff]
     ) -> list[OpenAISessionFunction]:
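
The rewritten _get_session_config above switches from value-based defaults (model_settings.get(key, default)) to presence checks (key in model_settings): a key that is absent falls back to DEFAULT_MODEL_SETTINGS, while a key that is present but set to None is forwarded as an explicit None instead of being overwritten by the default. A minimal sketch of that resolution pattern (the default value below is invented for illustration):

    DEFAULTS = {"turn_detection": {"type": "server_vad"}}  # hypothetical default


    def resolve(settings: dict, key: str):
        # A present key wins even when its value is None; an absent key falls back.
        return settings[key] if key in settings else DEFAULTS.get(key)


    print(resolve({}, "turn_detection"))                        # {'type': 'server_vad'}
    print(resolve({"turn_detection": None}, "turn_detection"))  # None -> explicitly disabled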
tests/realtime/test_openai_realtime.py (40 additions, 0 deletions)
@@ -309,6 +309,46 @@ async def test_connect_already_connected_assertion(self, model, mock_websocket):
             with pytest.raises(AssertionError, match="Already connected"):
                 await model.connect(config)

+    @pytest.mark.asyncio
+    async def test_session_update_disable_turn_detection(self, model, mock_websocket):
+        """Session.update should allow users to disable turn-detection."""
+        config = {
+            "api_key": "test-api-key-123",
+            "initial_model_settings": {
+                "model_name": "gpt-4o-realtime-preview",
+                "turn_detection": None,
+            },
+        }
+
+        sent_messages: list[dict[str, Any]] = []
+
+        async def async_websocket(*args, **kwargs):
+            async def send(payload: str):
+                sent_messages.append(json.loads(payload))
+                return None
+
+            mock_websocket.send.side_effect = send
+            return mock_websocket
+
+        with patch("websockets.connect", side_effect=async_websocket):
+            with patch("asyncio.create_task") as mock_create_task:
+                mock_task = AsyncMock()
+
+                def mock_create_task_func(coro):
+                    coro.close()
+                    return mock_task
+
+                mock_create_task.side_effect = mock_create_task_func
+                await model.connect(config)
+
+        # Find the session.update events
+        session_updates = [m for m in sent_messages if m.get("type") == "session.update"]
+        assert len(session_updates) >= 1
+        # Verify the last session.update sends an explicit null for turn_detection
+        session = session_updates[-1]["session"]
+        assert "audio" in session and "input" in session["audio"]
+        assert session["audio"]["input"]["turn_detection"] is None
+

 class TestEventHandlingRobustness(TestOpenAIRealtimeWebSocketModel):
     """Test event parsing, validation, and error handling robustness."""
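
A hedged usage sketch of what the new test exercises, assuming the class under test is OpenAIRealtimeWebSocketModel from agents.realtime.openai_realtime and that it can be constructed without arguments (the connect config mirrors the test fixture; the API key is a placeholder):

    import asyncio

    from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel


    async def main() -> None:
        model = OpenAIRealtimeWebSocketModel()
        # Passing turn_detection=None is now serialized as an explicit null in session.update,
        # disabling server-side turn detection instead of falling back to the default.
        await model.connect(
            {
                "api_key": "YOUR_API_KEY",
                "initial_model_settings": {
                    "model_name": "gpt-realtime",
                    "turn_detection": None,
                },
            }
        )


    asyncio.run(main())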