Skip to content

Commit d863f5c

Browse files
authored
Add heartbeat test and fix bug (#984)
* Add heartbeat test and fix bug * Reduce intervals for speed * Linting * Skip new test on time skipping server * Update timings
1 parent f815886 commit d863f5c

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

temporalio/contrib/openai_agents/_heartbeat_decorator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
2424
if heartbeat_task:
2525
heartbeat_task.cancel()
2626
# Wait for heartbeat cancellation to complete
27-
await heartbeat_task
27+
try:
28+
await heartbeat_task
29+
except asyncio.CancelledError:
30+
pass
2831

2932
return cast(F, wrapper)
3033

tests/contrib/openai_agents/test_openai.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import asyncio
12
import json
23
import os
34
import uuid
45
from dataclasses import dataclass
56
from datetime import timedelta
6-
from typing import Any, Optional, Union, no_type_check
7+
from typing import Any, AsyncIterator, Optional, Union, no_type_check
78

89
import nexusrpc
910
import pytest
@@ -39,6 +40,7 @@
3940
HandoffOutputItem,
4041
ToolCallItem,
4142
ToolCallOutputItem,
43+
TResponseStreamEvent,
4244
)
4345
from openai import APIStatusError, AsyncOpenAI, BaseModel
4446
from openai.types.responses import (
@@ -1876,3 +1878,83 @@ async def test_chat_completions_model(client: Client):
18761878
execution_timeout=timedelta(seconds=10),
18771879
)
18781880
await workflow_handle.result()
1881+
1882+
1883+
class WaitModel(Model):
1884+
async def get_response(
1885+
self,
1886+
system_instructions: Union[str, None],
1887+
input: Union[str, list[TResponseInputItem]],
1888+
model_settings: ModelSettings,
1889+
tools: list[Tool],
1890+
output_schema: Union[AgentOutputSchemaBase, None],
1891+
handoffs: list[Handoff],
1892+
tracing: ModelTracing,
1893+
*,
1894+
previous_response_id: Union[str, None],
1895+
prompt: Union[ResponsePromptParam, None] = None,
1896+
) -> ModelResponse:
1897+
activity.logger.info("Waiting")
1898+
await asyncio.sleep(1.0)
1899+
activity.logger.info("Returning")
1900+
return ModelResponse(
1901+
output=[
1902+
ResponseOutputMessage(
1903+
id="",
1904+
content=[
1905+
ResponseOutputText(
1906+
text="test", annotations=[], type="output_text"
1907+
)
1908+
],
1909+
role="assistant",
1910+
status="completed",
1911+
type="message",
1912+
)
1913+
],
1914+
usage=Usage(),
1915+
response_id=None,
1916+
)
1917+
1918+
def stream_response(
1919+
self,
1920+
system_instructions: Optional[str],
1921+
input: Union[str, list[TResponseInputItem]],
1922+
model_settings: ModelSettings,
1923+
tools: list[Tool],
1924+
output_schema: Optional[AgentOutputSchemaBase],
1925+
handoffs: list[Handoff],
1926+
tracing: ModelTracing,
1927+
*,
1928+
previous_response_id: Optional[str],
1929+
prompt: Optional[ResponsePromptParam],
1930+
) -> AsyncIterator[TResponseStreamEvent]:
1931+
raise NotImplementedError()
1932+
1933+
1934+
async def test_heartbeat(client: Client, env: WorkflowEnvironment):
1935+
if env.supports_time_skipping:
1936+
pytest.skip("Relies on real timing, skip.")
1937+
1938+
new_config = client.config()
1939+
new_config["plugins"] = [
1940+
openai_agents.OpenAIAgentsPlugin(
1941+
model_params=ModelActivityParameters(
1942+
heartbeat_timeout=timedelta(seconds=0.5),
1943+
),
1944+
model_provider=TestModelProvider(WaitModel()),
1945+
)
1946+
]
1947+
client = Client(**new_config)
1948+
1949+
async with new_worker(
1950+
client,
1951+
HelloWorldAgent,
1952+
) as worker:
1953+
workflow_handle = await client.start_workflow(
1954+
HelloWorldAgent.run,
1955+
"Tell me about recursion in programming.",
1956+
id=f"workflow-tool-{uuid.uuid4()}",
1957+
task_queue=worker.task_queue,
1958+
execution_timeout=timedelta(seconds=5.0),
1959+
)
1960+
await workflow_handle.result()

0 commit comments

Comments
 (0)