@@ -38,32 +38,24 @@ def __init__(
3838 user_data : AuthUserInfo = None ,
3939 ws : WebSocket = None ,
4040 ):
41+ """Initialize the StrayCat object."""
42+
43+ # user data
4144 self .__user_id = user_id
4245 self .__user_data = user_data
4346
44- # get working memory from cache or create a new one
45- log .warning (f"GET working memory for { user_id } " )
46- self .working_memory = self .cache .get_value (f"{ user_id } _working_memory" ) or WorkingMemory ()
47-
4847 # attribute to store ws connection
4948 self .__ws = ws
5049
50+ # main event loop (for ws messages)
5151 self .__main_loop = main_loop
52+
53+ # get working memory from cache or create a new one
54+ self .load_working_memory_from_cache ()
5255
5356 def __repr__ (self ):
5457 return f"StrayCat(user_id={ self .user_id } )"
5558
56- #def __del__(self):
57- # log.critical(f"StrayCat __del__ called for {self.user_id}")
58- #self.__main_loop = None
59- #self.__ws = None
60- # when the garbage collector deletes the stray, we update working memory in cache
61- # self.update_working_memory_cache()
62- #self.__user_id = None
63- #self.__user_data = None
64- # del self
65-
66-
6759 def __send_ws_json (self , data : Any ):
6860 # Run the corutine in the main event loop in the main thread
6961 # and wait for the result
@@ -101,9 +93,15 @@ def __build_why(self) -> MessageWhy:
10193
10294 return why
10395
96+ def load_working_memory_from_cache (self ):
97+ """Load the working memory from the cache."""
98+ log .warning (f"GET working memory for { self .user_id } " )
99+ self .working_memory = \
100+ self .cache .get_value (f"{ self .user_id } _working_memory" ) or WorkingMemory ()
101+
104102 def update_working_memory_cache (self ):
105103 """Update the working memory in the cache."""
106- log .warning (f"SAVE working memory for { self .user_id } " )
104+ log .critical (f"SAVE { self .user_id } " )
107105 updated_cache_item = CacheItem (f"{ self .user_id } _working_memory" , self .working_memory , - 1 )
108106 self .cache .insert (updated_cache_item )
109107
@@ -392,7 +390,7 @@ def __call__(self, message_dict):
392390 user_message = UserMessage .model_validate (message_dict )
393391 log .info (user_message )
394392
395- ### setup working memory
393+ ### setup working memory for this convo turn
396394 # keeping track of model interactions
397395 self .working_memory .model_interactions = []
398396 # latest user message
@@ -488,8 +486,14 @@ def __call__(self, message_dict):
488486
489487 def run (self , user_message_json , return_message = False ):
490488 try :
489+
490+ # load working memory from cache
491+ self .load_working_memory_from_cache ()
492+ # run main flow
491493 cat_message = self .__call__ (user_message_json )
494+ # save working memory to cache
492495 self .update_working_memory_cache ()
496+
493497 if return_message :
494498 # return the message for HTTP usage
495499 return cat_message
0 commit comments