Skip to content

Commit e0e6794

Browse files
authored
Allow user to send None values to Realtime API (#2152)
1 parent 9edc53f commit e0e6794

File tree

2 files changed

+106
-72
lines changed

2 files changed

+106
-72
lines changed

src/agents/realtime/openai_realtime.py

Lines changed: 66 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ async def send_event(self, event: RealtimeModelSendEvent) -> None:
348348
async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None:
349349
"""Send a raw message to the model."""
350350
assert self._websocket is not None, "Not connected"
351-
payload = event.model_dump_json(exclude_none=True, exclude_unset=True)
351+
payload = event.model_dump_json(exclude_unset=True)
352352
await self._websocket.send(payload)
353353

354354
async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None:
@@ -829,91 +829,63 @@ def _get_session_config(
829829
self, model_settings: RealtimeSessionModelSettings
830830
) -> OpenAISessionCreateRequest:
831831
"""Get the session config."""
832-
model_name = (model_settings.get("model_name") or self.model) or "gpt-realtime"
833-
834-
voice = model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice"))
835-
speed = model_settings.get("speed")
836-
modalities = model_settings.get("modalities", DEFAULT_MODEL_SETTINGS.get("modalities"))
832+
audio_input_args = {}
837833

838834
if self._call_id:
839-
input_audio_format = model_settings.get("input_audio_format")
840-
else:
841-
input_audio_format = model_settings.get(
842-
"input_audio_format",
843-
DEFAULT_MODEL_SETTINGS.get("input_audio_format"),
835+
audio_input_args["format"] = to_realtime_audio_format(
836+
model_settings.get("input_audio_format")
844837
)
845-
input_audio_transcription = model_settings.get(
846-
"input_audio_transcription",
847-
DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"),
848-
)
849-
turn_detection = model_settings.get(
850-
"turn_detection",
851-
DEFAULT_MODEL_SETTINGS.get("turn_detection"),
852-
)
853-
if self._call_id:
854-
output_audio_format = model_settings.get("output_audio_format")
855838
else:
856-
output_audio_format = model_settings.get(
857-
"output_audio_format",
858-
DEFAULT_MODEL_SETTINGS.get("output_audio_format"),
839+
audio_input_args["format"] = to_realtime_audio_format(
840+
model_settings.get(
841+
"input_audio_format", DEFAULT_MODEL_SETTINGS.get("input_audio_format")
842+
)
859843
)
860-
input_audio_noise_reduction = model_settings.get(
861-
"input_audio_noise_reduction",
862-
DEFAULT_MODEL_SETTINGS.get("input_audio_noise_reduction"),
863-
)
864844

865-
input_audio_config = None
866-
if any(
867-
value is not None
868-
for value in [
869-
input_audio_format,
870-
input_audio_noise_reduction,
871-
input_audio_transcription,
872-
turn_detection,
873-
]
874-
):
875-
input_audio_config = OpenAIRealtimeAudioInput(
876-
format=to_realtime_audio_format(input_audio_format),
877-
noise_reduction=cast(Any, input_audio_noise_reduction),
878-
transcription=cast(Any, input_audio_transcription),
879-
turn_detection=cast(Any, turn_detection),
880-
)
845+
if "input_audio_noise_reduction" in model_settings:
846+
audio_input_args["noise_reduction"] = model_settings.get("input_audio_noise_reduction") # type: ignore[assignment]
881847

882-
output_audio_config = None
883-
if any(value is not None for value in [output_audio_format, speed, voice]):
884-
output_audio_config = OpenAIRealtimeAudioOutput(
885-
format=to_realtime_audio_format(output_audio_format),
886-
speed=speed,
887-
voice=voice,
848+
if "input_audio_transcription" in model_settings:
849+
audio_input_args["transcription"] = model_settings.get("input_audio_transcription") # type: ignore[assignment]
850+
else:
851+
audio_input_args["transcription"] = DEFAULT_MODEL_SETTINGS.get( # type: ignore[assignment]
852+
"input_audio_transcription"
888853
)
889854

890-
audio_config = None
891-
if input_audio_config or output_audio_config:
892-
audio_config = OpenAIRealtimeAudioConfig(
893-
input=input_audio_config,
894-
output=output_audio_config,
895-
)
855+
if "turn_detection" in model_settings:
856+
audio_input_args["turn_detection"] = model_settings.get("turn_detection") # type: ignore[assignment]
857+
else:
858+
audio_input_args["turn_detection"] = DEFAULT_MODEL_SETTINGS.get("turn_detection") # type: ignore[assignment]
896859

897-
prompt: ResponsePrompt | None = None
898-
if model_settings.get("prompt") is not None:
899-
_passed_prompt: Prompt = model_settings["prompt"]
900-
variables: dict[str, Any] | None = _passed_prompt.get("variables")
901-
prompt = ResponsePrompt(
902-
id=_passed_prompt["id"],
903-
variables=variables,
904-
version=_passed_prompt.get("version"),
860+
audio_output_args = {
861+
"voice": model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")),
862+
}
863+
864+
if self._call_id:
865+
audio_output_args["format"] = to_realtime_audio_format( # type: ignore[assignment]
866+
model_settings.get("output_audio_format")
905867
)
868+
else:
869+
audio_output_args["format"] = to_realtime_audio_format( # type: ignore[assignment]
870+
model_settings.get(
871+
"output_audio_format", DEFAULT_MODEL_SETTINGS.get("output_audio_format")
872+
)
873+
)
874+
875+
if "speed" in model_settings:
876+
audio_output_args["speed"] = model_settings.get("speed") # type: ignore[assignment]
906877

907878
# Construct full session object. `type` will be excluded at serialization time for updates.
908-
return OpenAISessionCreateRequest(
909-
model=model_name,
879+
session_create_request = OpenAISessionCreateRequest(
910880
type="realtime",
911-
instructions=model_settings.get("instructions"),
912-
prompt=prompt,
913-
output_modalities=modalities,
914-
audio=audio_config,
915-
max_output_tokens=cast(Any, model_settings.get("max_output_tokens")),
916-
tool_choice=cast(Any, model_settings.get("tool_choice")),
881+
model=(model_settings.get("model_name") or self.model) or "gpt-realtime",
882+
output_modalities=model_settings.get(
883+
"modalities", DEFAULT_MODEL_SETTINGS.get("modalities")
884+
),
885+
audio=OpenAIRealtimeAudioConfig(
886+
input=OpenAIRealtimeAudioInput(**audio_input_args), # type: ignore[arg-type]
887+
output=OpenAIRealtimeAudioOutput(**audio_output_args), # type: ignore[arg-type]
888+
),
917889
tools=cast(
918890
Any,
919891
self._tools_to_session_tools(
@@ -923,6 +895,28 @@ def _get_session_config(
923895
),
924896
)
925897

898+
if "instructions" in model_settings:
899+
session_create_request.instructions = model_settings.get("instructions")
900+
901+
if "prompt" in model_settings:
902+
_passed_prompt: Prompt = model_settings["prompt"]
903+
variables: dict[str, Any] | None = _passed_prompt.get("variables")
904+
session_create_request.prompt = ResponsePrompt(
905+
id=_passed_prompt["id"],
906+
variables=variables,
907+
version=_passed_prompt.get("version"),
908+
)
909+
910+
if "max_output_tokens" in model_settings:
911+
session_create_request.max_output_tokens = cast(
912+
Any, model_settings.get("max_output_tokens")
913+
)
914+
915+
if "tool_choice" in model_settings:
916+
session_create_request.tool_choice = cast(Any, model_settings.get("tool_choice"))
917+
918+
return session_create_request
919+
926920
def _tools_to_session_tools(
927921
self, tools: list[Tool], handoffs: list[Handoff]
928922
) -> list[OpenAISessionFunction]:

tests/realtime/test_openai_realtime.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,46 @@ async def test_connect_already_connected_assertion(self, model, mock_websocket):
309309
with pytest.raises(AssertionError, match="Already connected"):
310310
await model.connect(config)
311311

312+
@pytest.mark.asyncio
313+
async def test_session_update_disable_turn_detection(self, model, mock_websocket):
314+
"""Session.update should allow users to disable turn-detection."""
315+
config = {
316+
"api_key": "test-api-key-123",
317+
"initial_model_settings": {
318+
"model_name": "gpt-4o-realtime-preview",
319+
"turn_detection": None,
320+
},
321+
}
322+
323+
sent_messages: list[dict[str, Any]] = []
324+
325+
async def async_websocket(*args, **kwargs):
326+
async def send(payload: str):
327+
sent_messages.append(json.loads(payload))
328+
return None
329+
330+
mock_websocket.send.side_effect = send
331+
return mock_websocket
332+
333+
with patch("websockets.connect", side_effect=async_websocket):
334+
with patch("asyncio.create_task") as mock_create_task:
335+
mock_task = AsyncMock()
336+
337+
def mock_create_task_func(coro):
338+
coro.close()
339+
return mock_task
340+
341+
mock_create_task.side_effect = mock_create_task_func
342+
await model.connect(config)
343+
344+
# Find the session.update events
345+
session_updates = [m for m in sent_messages if m.get("type") == "session.update"]
346+
assert len(session_updates) >= 1
347+
# Verify the last session.update sends turn_detection explicitly as None
348+
session = session_updates[-1]["session"]
349+
assert "audio" in session and "input" in session["audio"]
350+
assert session["audio"]["input"]["turn_detection"] is None
351+
312352

313353
class TestEventHandlingRobustness(TestOpenAIRealtimeWebSocketModel):
314354
"""Test event parsing, validation, and error handling robustness."""

0 commit comments

Comments
 (0)