Skip to content

Commit

Permalink
keep references to background tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
pschichtel committed Jan 8, 2024
1 parent 81f126f commit a786dba
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions rasa_vier_cvg/cvg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import base64
import logging
from functools import wraps
from typing import Any, Awaitable, Callable, Dict, Optional, Text, TypeVar, Coroutine
from typing import Any, Awaitable, Callable, Dict, Optional, Text, TypeVar, Coroutine, Set
import warnings
import aiohttp

Expand Down Expand Up @@ -52,12 +52,22 @@ def create_recipient_id(reseller_token, project_token, dialog_id) -> Text:
return base64.b64encode(bytes(json_representation, 'utf-8')).decode('utf-8')


class TaskContainer:
tasks: Set[asyncio.Task] = set()

def run(self, coro: Coroutine[Any, Any, None]):
task = asyncio.create_task(coro)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)


class CVGOutput(OutputChannel):
"""Output channel for the Cognitive Voice Gateway"""

on_message: Callable[[UserMessage], Awaitable[Any]]
base_url: str
proxy: Optional[str]
task_container: TaskContainer = TaskContainer()

@classmethod
def name(cls) -> Text:
Expand Down Expand Up @@ -89,8 +99,7 @@ async def perform():
status, body = await self._perform_request(path, method, data)
await process_result(status, body)

# noinspection PyAsyncCall
asyncio.create_task(perform())
self.task_container.run(perform())

async def _say(self, dialog_id: str, text: str):
await self._perform_request("/call/say", method="POST", data={DIALOG_ID_FIELD: dialog_id, "text": text})
Expand Down Expand Up @@ -214,6 +223,7 @@ class CVGInput(InputChannel):
proxy: Optional[str]
expected_authorization_header_value: str
blocking_endpoints: bool
task_container: TaskContainer = TaskContainer()

@classmethod
def name(cls) -> Text:
Expand Down Expand Up @@ -316,8 +326,7 @@ async def process_request(request: Request, text: Text, must_block: bool):
if self.blocking_endpoints or must_block:
await result
else:
# noinspection PyAsyncCall
asyncio.create_task(result)
self.task_container.run(result)

return response.empty(204)

Expand Down

0 comments on commit a786dba

Please sign in to comment.