Skip to content

Commit d2c6856

Browse files
committed
damn stateful websocket
1 parent ea9ce1a commit d2c6856

File tree

3 files changed

+23
-21
lines changed

3 files changed

+23
-21
lines changed

core/cat/auth/connection.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# to have a standard auth interface.
44

55
from abc import ABC, abstractmethod
6-
from typing import Tuple
6+
from typing import Tuple, AsyncGenerator
77
import asyncio
88
from urllib.parse import urlencode
99

@@ -37,7 +37,7 @@ def __init__(
3737
async def __call__(
3838
self,
3939
connection: HTTPConnection # Request | WebSocket,
40-
) -> StrayCat:
40+
) -> AsyncGenerator[StrayCat, None]:
4141

4242
# get protocol from Starlette request
4343
protocol = connection.scope.get('type')
@@ -57,7 +57,6 @@ async def __call__(
5757
stray = await self.get_user_stray(user, connection)
5858
yield stray
5959

60-
log.critical("STRAY FINITO")
6160
stray.update_working_memory_cache()
6261
del stray
6362
return

core/cat/looking_glass/stray_cat.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,32 +38,24 @@ def __init__(
3838
user_data: AuthUserInfo = None,
3939
ws: WebSocket = None,
4040
):
41+
"""Initialize the StrayCat object."""
42+
43+
# user data
4144
self.__user_id = user_id
4245
self.__user_data = user_data
4346

44-
# get working memory from cache or create a new one
45-
log.warning(f"GET working memory for {user_id}")
46-
self.working_memory = self.cache.get_value(f"{user_id}_working_memory") or WorkingMemory()
47-
4847
# attribute to store ws connection
4948
self.__ws = ws
5049

50+
# main event loop (for ws messages)
5151
self.__main_loop = main_loop
52+
53+
# get working memory from cache or create a new one
54+
self.load_working_memory_from_cache()
5255

5356
def __repr__(self):
5457
return f"StrayCat(user_id={self.user_id})"
5558

56-
#def __del__(self):
57-
# log.critical(f"StrayCat __del__ called for {self.user_id}")
58-
#self.__main_loop = None
59-
#self.__ws = None
60-
# when the garbage collector deletes the stray, we update working memory in cache
61-
# self.update_working_memory_cache()
62-
#self.__user_id = None
63-
#self.__user_data = None
64-
# del self
65-
66-
6759
def __send_ws_json(self, data: Any):
6860
# Run the corutine in the main event loop in the main thread
6961
# and wait for the result
@@ -101,9 +93,15 @@ def __build_why(self) -> MessageWhy:
10193

10294
return why
10395

96+
def load_working_memory_from_cache(self):
97+
"""Load the working memory from the cache."""
98+
log.warning(f"GET working memory for {self.user_id}")
99+
self.working_memory = \
100+
self.cache.get_value(f"{self.user_id}_working_memory") or WorkingMemory()
101+
104102
def update_working_memory_cache(self):
105103
"""Update the working memory in the cache."""
106-
log.warning(f"SAVE working memory for {self.user_id}")
104+
log.critical(f"SAVE {self.user_id}")
107105
updated_cache_item = CacheItem(f"{self.user_id}_working_memory", self.working_memory, -1)
108106
self.cache.insert(updated_cache_item)
109107

@@ -392,7 +390,7 @@ def __call__(self, message_dict):
392390
user_message = UserMessage.model_validate(message_dict)
393391
log.info(user_message)
394392

395-
### setup working memory
393+
### setup working memory for this convo turn
396394
# keeping track of model interactions
397395
self.working_memory.model_interactions = []
398396
# latest user message
@@ -488,8 +486,14 @@ def __call__(self, message_dict):
488486

489487
def run(self, user_message_json, return_message=False):
490488
try:
489+
490+
# load working memory from cache
491+
self.load_working_memory_from_cache()
492+
# run main flow
491493
cat_message = self.__call__(user_message_json)
494+
# save working memory to cache
492495
self.update_working_memory_cache()
496+
493497
if return_message:
494498
# return the message for HTTP usage
495499
return cat_message

core/cat/routes/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,4 @@ async def message_with_cat(
2929
"""Get a response from the Cat"""
3030
user_message_json = {"user_id": stray.user_id, **payload}
3131
answer = await run_in_threadpool(stray.run, user_message_json, True)
32-
del stray
3332
return answer

0 commit comments

Comments
 (0)