Skip to content

Commit

Permalink
allow performing all requests async
Browse files Browse the repository at this point in the history
  • Loading branch information
pschichtel committed Dec 18, 2024
1 parent f95cdd8 commit 5e08037
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions rasa_vier_cvg/cvg.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,18 @@ class CVGOutput(OutputChannel):
on_message: Callable[[UserMessage], Awaitable[Any]]
base_url: str
proxy: Optional[str]
task_container: TaskContainer = TaskContainer()

@classmethod
def name(cls) -> Text:
return CHANNEL_NAME

def __init__(self, callback_base_url: Text, on_message: Callable[[UserMessage], Awaitable[Any]], proxy: Optional[str] = None) -> None:
def __init__(self, callback_base_url: Text, on_message: Callable[[UserMessage], Awaitable[Any]], proxy: Optional[str], task_container: TaskContainer, blocking_output: bool) -> None:
self.on_message = on_message

self.base_url = callback_base_url.rstrip('/')
self.proxy = proxy
self.task_container = task_container
self.blocking_output = blocking_output

# This functionality can be used to ignore certain messages received by this channel.
# It can be used as a workaround for dialog setups that produce messages that should not be forwarded to CVG but still be tracked.
Expand All @@ -89,7 +90,7 @@ def __init__(self, callback_base_url: Text, on_message: Callable[[UserMessage],
def _is_ignored(self, custom_json) -> bool:
return custom_json is not None and "ignore" in custom_json and custom_json["ignore"] is True

async def _perform_request(self, path: str, method: str, data: Optional[any], dialog_id: Optional[str], retries: int = 0) -> (Optional[int], any):
async def _perform_request_sync(self, path: str, method: str, data: Optional[any], dialog_id: Optional[str], retries: int = 0) -> (Optional[int], any):
url = f"{self.base_url}{path}"
try:
async with aiohttp.request(method, url, json=data, proxy=self.proxy) as res:
Expand All @@ -107,17 +108,29 @@ async def _perform_request(self, path: str, method: str, data: Optional[any], di
except aiohttp.ClientConnectionError:
if retries < 3:
logger.error(f"{dialog_id} - The connection failed, retrying...")
await self._perform_request(path, method, data, dialog_id, retries + 1)
await self._perform_request_sync(path, method, data, dialog_id, retries + 1)
else:
logger.error(f"{dialog_id} - {retries} retries all failed, that's it!")

def _perform_request_async(self, path: str, method: str, data: Optional[any], dialog_id: Optional[str], process_result: Callable[[int, any], Coroutine[Any, Any, None]]):
async def perform():
status, body = await self._perform_request(path, method, data, dialog_id)
status, body = await self._perform_request_sync(path, method, data, dialog_id)
await process_result(status, body)

self.task_container.run(perform())

async def _perform_request(self, path: str, method: str, data: Optional[any], dialog_id: Optional[str]):
async def handle_result(status_code, response_body):
if not 200 <= status_code < 300:
logger.info(f"{dialog_id} - {method} request to {path} failed: {status_code} with body {response_body}")
return

if self.blocking_output:
result = await self._perform_request_sync(path, method, data, dialog_id)
await handle_result(*result)
else:
self._perform_request_async(path, method, data, dialog_id, handle_result)

async def _say(self, dialog_id: str, text: str):
if len(text.strip()) > 0:
await self._perform_request("/call/say", method="POST", data={DIALOG_ID_FIELD: dialog_id, "text": text}, dialog_id=dialog_id)
Expand Down Expand Up @@ -246,6 +259,7 @@ class CVGInput(InputChannel):
proxy: Optional[str]
expected_authorization_header_value: str
blocking_endpoints: bool
blocking_output: bool
ignore_messages_when_busy: bool
task_container: TaskContainer = TaskContainer()
# This Set is not thread safe. However, sanic is not multithreaded.
Expand All @@ -271,22 +285,28 @@ def from_credentials(cls, credentials: Optional[Dict[Text, Any]]) -> InputChanne
blocking_endpoints = True
else:
blocking_endpoints = bool(blocking_endpoints)
blocking_output = credentials.get("blocking_endpoints")
if blocking_output is None:
blocking_output = True
else:
blocking_output = bool(blocking_output)

ignore_messages_when_busy = credentials.get("ignore_messages_when_busy")
if ignore_messages_when_busy is None:
ignore_messages_when_busy = False
else:
ignore_messages_when_busy = bool(ignore_messages_when_busy)

logger.info(f"Creating input with: token={'*' * len(token)} proxy={proxy} start_intent={start_intent} blocking_endpoints={blocking_endpoints} ignore_messages_when_busy={ignore_messages_when_busy}")
return cls(token, start_intent, proxy, blocking_endpoints, ignore_messages_when_busy)
logger.info(f"Creating input with: token={'*' * len(token)} proxy={proxy} start_intent={start_intent} blocking_endpoints={blocking_endpoints} blocking_output={blocking_output} ignore_messages_when_busy={ignore_messages_when_busy}")
return cls(token, start_intent, proxy, blocking_endpoints, blocking_output, ignore_messages_when_busy)

def __init__(self, token: Text, start_intent: Text, proxy: Optional[Text], blocking_endpoints: bool, ignore_messages_when_busy: bool) -> None:
def __init__(self, token: Text, start_intent: Text, proxy: Optional[Text], blocking_endpoints: bool, blocking_output: bool, ignore_messages_when_busy: bool) -> None:
self.callback = None
self.expected_authorization_header_value = f"Bearer {token}"
self.proxy = proxy
self.start_intent = start_intent
self.blocking_endpoints = blocking_endpoints
self.blocking_output = blocking_output
self.ignore_messages_when_busy = ignore_messages_when_busy

async def _process_message(self, request: Request, on_new_message: Callable[[UserMessage], Awaitable[Any]], dialog_id: Text, text: Text, sender_id: Text) -> Any:
Expand All @@ -297,7 +317,7 @@ async def _process_message(self, request: Request, on_new_message: Callable[[Use
metadata = make_metadata(request.json)
user_msg = UserMessage(
text=text,
output_channel=CVGOutput(request.json[CALLBACK_FIELD], on_new_message, self.proxy),
output_channel=CVGOutput(request.json[CALLBACK_FIELD], on_new_message, self.proxy, self.task_container, self.blocking_output),
sender_id=sender_id,
input_channel=CHANNEL_NAME,
metadata=metadata,
Expand All @@ -316,7 +336,7 @@ async def _process_message(self, request: Request, on_new_message: Callable[[Use
try:
await on_new_message(user_msg)
finally:
if (self.ignore_messages_when_busy):
if self.ignore_messages_when_busy:
self.ignore_messages_for.remove(dialog_id)
except Exception as e:
logger.error(f"{dialog_id} - Exception when trying to handle message: {e}")
Expand Down

0 comments on commit 5e08037

Please sign in to comment.