|
| 1 | +import asyncio |
1 | 2 | import json
|
2 | 3 | import os
|
3 | 4 | import uuid
|
4 | 5 | from dataclasses import dataclass
|
5 | 6 | from datetime import timedelta
|
6 |
| -from typing import Any, Optional, Union, no_type_check |
| 7 | +from typing import Any, AsyncIterator, Optional, Union, no_type_check |
7 | 8 |
|
8 | 9 | import nexusrpc
|
9 | 10 | import pytest
|
|
39 | 40 | HandoffOutputItem,
|
40 | 41 | ToolCallItem,
|
41 | 42 | ToolCallOutputItem,
|
| 43 | + TResponseStreamEvent, |
42 | 44 | )
|
43 | 45 | from openai import APIStatusError, AsyncOpenAI, BaseModel
|
44 | 46 | from openai.types.responses import (
|
@@ -1876,3 +1878,83 @@ async def test_chat_completions_model(client: Client):
|
1876 | 1878 | execution_timeout=timedelta(seconds=10),
|
1877 | 1879 | )
|
1878 | 1880 | await workflow_handle.result()
|
| 1881 | + |
| 1882 | + |
| 1883 | +class WaitModel(Model): |
| 1884 | + async def get_response( |
| 1885 | + self, |
| 1886 | + system_instructions: Union[str, None], |
| 1887 | + input: Union[str, list[TResponseInputItem]], |
| 1888 | + model_settings: ModelSettings, |
| 1889 | + tools: list[Tool], |
| 1890 | + output_schema: Union[AgentOutputSchemaBase, None], |
| 1891 | + handoffs: list[Handoff], |
| 1892 | + tracing: ModelTracing, |
| 1893 | + *, |
| 1894 | + previous_response_id: Union[str, None], |
| 1895 | + prompt: Union[ResponsePromptParam, None] = None, |
| 1896 | + ) -> ModelResponse: |
| 1897 | + activity.logger.info("Waiting") |
| 1898 | + await asyncio.sleep(1.0) |
| 1899 | + activity.logger.info("Returning") |
| 1900 | + return ModelResponse( |
| 1901 | + output=[ |
| 1902 | + ResponseOutputMessage( |
| 1903 | + id="", |
| 1904 | + content=[ |
| 1905 | + ResponseOutputText( |
| 1906 | + text="test", annotations=[], type="output_text" |
| 1907 | + ) |
| 1908 | + ], |
| 1909 | + role="assistant", |
| 1910 | + status="completed", |
| 1911 | + type="message", |
| 1912 | + ) |
| 1913 | + ], |
| 1914 | + usage=Usage(), |
| 1915 | + response_id=None, |
| 1916 | + ) |
| 1917 | + |
| 1918 | + def stream_response( |
| 1919 | + self, |
| 1920 | + system_instructions: Optional[str], |
| 1921 | + input: Union[str, list[TResponseInputItem]], |
| 1922 | + model_settings: ModelSettings, |
| 1923 | + tools: list[Tool], |
| 1924 | + output_schema: Optional[AgentOutputSchemaBase], |
| 1925 | + handoffs: list[Handoff], |
| 1926 | + tracing: ModelTracing, |
| 1927 | + *, |
| 1928 | + previous_response_id: Optional[str], |
| 1929 | + prompt: Optional[ResponsePromptParam], |
| 1930 | + ) -> AsyncIterator[TResponseStreamEvent]: |
| 1931 | + raise NotImplementedError() |
| 1932 | + |
| 1933 | + |
| 1934 | +async def test_heartbeat(client: Client, env: WorkflowEnvironment): |
| 1935 | + if env.supports_time_skipping: |
| 1936 | + pytest.skip("Relies on real timing, skip.") |
| 1937 | + |
| 1938 | + new_config = client.config() |
| 1939 | + new_config["plugins"] = [ |
| 1940 | + openai_agents.OpenAIAgentsPlugin( |
| 1941 | + model_params=ModelActivityParameters( |
| 1942 | + heartbeat_timeout=timedelta(seconds=0.5), |
| 1943 | + ), |
| 1944 | + model_provider=TestModelProvider(WaitModel()), |
| 1945 | + ) |
| 1946 | + ] |
| 1947 | + client = Client(**new_config) |
| 1948 | + |
| 1949 | + async with new_worker( |
| 1950 | + client, |
| 1951 | + HelloWorldAgent, |
| 1952 | + ) as worker: |
| 1953 | + workflow_handle = await client.start_workflow( |
| 1954 | + HelloWorldAgent.run, |
| 1955 | + "Tell me about recursion in programming.", |
| 1956 | + id=f"workflow-tool-{uuid.uuid4()}", |
| 1957 | + task_queue=worker.task_queue, |
| 1958 | + execution_timeout=timedelta(seconds=5.0), |
| 1959 | + ) |
| 1960 | + await workflow_handle.result() |
0 commit comments