Skip to content

Commit a786dba

Browse files
committed
keep references to background tasks
1 parent 81f126f commit a786dba

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

rasa_vier_cvg/cvg.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import base64
55
import logging
66
from functools import wraps
7-
from typing import Any, Awaitable, Callable, Dict, Optional, Text, TypeVar, Coroutine
7+
from typing import Any, Awaitable, Callable, Dict, Optional, Text, TypeVar, Coroutine, Set
88
import warnings
99
import aiohttp
1010

@@ -52,12 +52,22 @@ def create_recipient_id(reseller_token, project_token, dialog_id) -> Text:
5252
return base64.b64encode(bytes(json_representation, 'utf-8')).decode('utf-8')
5353

5454

55+
class TaskContainer:
56+
tasks: Set[asyncio.Task] = set()
57+
58+
def run(self, coro: Coroutine[Any, Any, None]):
59+
task = asyncio.create_task(coro)
60+
self.tasks.add(task)
61+
task.add_done_callback(self.tasks.discard)
62+
63+
5564
class CVGOutput(OutputChannel):
5665
"""Output channel for the Cognitive Voice Gateway"""
5766

5867
on_message: Callable[[UserMessage], Awaitable[Any]]
5968
base_url: str
6069
proxy: Optional[str]
70+
task_container: TaskContainer = TaskContainer()
6171

6272
@classmethod
6373
def name(cls) -> Text:
@@ -89,8 +99,7 @@ async def perform():
8999
status, body = await self._perform_request(path, method, data)
90100
await process_result(status, body)
91101

92-
# noinspection PyAsyncCall
93-
asyncio.create_task(perform())
102+
self.task_container.run(perform())
94103

95104
async def _say(self, dialog_id: str, text: str):
96105
await self._perform_request("/call/say", method="POST", data={DIALOG_ID_FIELD: dialog_id, "text": text})
@@ -214,6 +223,7 @@ class CVGInput(InputChannel):
214223
proxy: Optional[str]
215224
expected_authorization_header_value: str
216225
blocking_endpoints: bool
226+
task_container: TaskContainer = TaskContainer()
217227

218228
@classmethod
219229
def name(cls) -> Text:
@@ -316,8 +326,7 @@ async def process_request(request: Request, text: Text, must_block: bool):
316326
if self.blocking_endpoints or must_block:
317327
await result
318328
else:
319-
# noinspection PyAsyncCall
320-
asyncio.create_task(result)
329+
self.task_container.run(result)
321330

322331
return response.empty(204)
323332

0 commit comments

Comments
 (0)