Skip to content

Commit

Permalink
damn stateful websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
pieroit committed Feb 6, 2025
1 parent ea9ce1a commit d2c6856
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
5 changes: 2 additions & 3 deletions core/cat/auth/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# to have a standard auth interface.

from abc import ABC, abstractmethod
from typing import Tuple
from typing import Tuple, AsyncGenerator
import asyncio
from urllib.parse import urlencode

Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
async def __call__(
self,
connection: HTTPConnection # Request | WebSocket,
) -> StrayCat:
) -> AsyncGenerator[StrayCat, None]:

# get protocol from Starlette request
protocol = connection.scope.get('type')
Expand All @@ -57,7 +57,6 @@ async def __call__(
stray = await self.get_user_stray(user, connection)
yield stray

log.critical("STRAY FINITO")
stray.update_working_memory_cache()
del stray
return
Expand Down
38 changes: 21 additions & 17 deletions core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,32 +38,24 @@ def __init__(
user_data: AuthUserInfo = None,
ws: WebSocket = None,
):
"""Initialize the StrayCat object."""

# user data
self.__user_id = user_id
self.__user_data = user_data

# get working memory from cache or create a new one
log.warning(f"GET working memory for {user_id}")
self.working_memory = self.cache.get_value(f"{user_id}_working_memory") or WorkingMemory()

# attribute to store ws connection
self.__ws = ws

# main event loop (for ws messages)
self.__main_loop = main_loop

# get working memory from cache or create a new one
self.load_working_memory_from_cache()

def __repr__(self):
return f"StrayCat(user_id={self.user_id})"

#def __del__(self):
# log.critical(f"StrayCat __del__ called for {self.user_id}")
#self.__main_loop = None
#self.__ws = None
# when the garbage collector deletes the stray, we update working memory in cache
# self.update_working_memory_cache()
#self.__user_id = None
#self.__user_data = None
# del self


def __send_ws_json(self, data: Any):
# Run the corutine in the main event loop in the main thread
# and wait for the result
Expand Down Expand Up @@ -101,9 +93,15 @@ def __build_why(self) -> MessageWhy:

return why

def load_working_memory_from_cache(self):
"""Load the working memory from the cache."""
log.warning(f"GET working memory for {self.user_id}")
self.working_memory = \
self.cache.get_value(f"{self.user_id}_working_memory") or WorkingMemory()

def update_working_memory_cache(self):
"""Update the working memory in the cache."""
log.warning(f"SAVE working memory for {self.user_id}")
log.critical(f"SAVE {self.user_id}")
updated_cache_item = CacheItem(f"{self.user_id}_working_memory", self.working_memory, -1)
self.cache.insert(updated_cache_item)

Expand Down Expand Up @@ -392,7 +390,7 @@ def __call__(self, message_dict):
user_message = UserMessage.model_validate(message_dict)
log.info(user_message)

### setup working memory
### setup working memory for this convo turn
# keeping track of model interactions
self.working_memory.model_interactions = []
# latest user message
Expand Down Expand Up @@ -488,8 +486,14 @@ def __call__(self, message_dict):

def run(self, user_message_json, return_message=False):
try:

# load working memory from cache
self.load_working_memory_from_cache()
# run main flow
cat_message = self.__call__(user_message_json)
# save working memory to cache
self.update_working_memory_cache()

if return_message:
# return the message for HTTP usage
return cat_message
Expand Down
1 change: 0 additions & 1 deletion core/cat/routes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,4 @@ async def message_with_cat(
"""Get a response from the Cat"""
user_message_json = {"user_id": stray.user_id, **payload}
answer = await run_in_threadpool(stray.run, user_message_json, True)
del stray
return answer

0 comments on commit d2c6856

Please sign in to comment.