
Commit 8d6b5ef

Merge branch 'samadpls-ref/move-history-methods-to-working-memory' into develop
2 parents: 497b15a + 05c73e4

7 files changed: +128 −48 lines changed

core/cat/agents/main_agent.py

Lines changed: 1 addition & 1 deletion

@@ -116,7 +116,7 @@ def format_agent_input(self, stray):

        # format conversation history to be inserted in the prompt
        # TODOV2: take away
-       conversation_history_formatted_content = stray.stringify_chat_history()
+       conversation_history_formatted_content = stray.working_memory.stringify_chat_history()

        return BaseModelDict(**{
            "episodic_memory": episodic_memory_formatted_content,

core/cat/agents/memory_agent.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ def execute(self, stray, prompt_prefix, prompt_suffix) -> AgentOutput:
                SystemMessagePromptTemplate.from_template(
                    template=sys_prompt
                ),
-               *(stray.langchainfy_chat_history()),
+               *(stray.working_memory.langchainfy_chat_history()),
            ]
        )

core/cat/agents/procedures_agent.py

Lines changed: 1 addition & 1 deletion

@@ -116,7 +116,7 @@ def execute_chain(self, stray, procedures_prompt_template, allowed_procedures) -
                SystemMessagePromptTemplate.from_template(
                    template=procedures_prompt_template
                ),
-               *(stray.langchainfy_chat_history()),
+               *(stray.working_memory.langchainfy_chat_history()),
            ]
        )

core/cat/experimental/form/cat_form.py

Lines changed: 1 addition & 1 deletion

@@ -216,7 +216,7 @@ def extract(self):
        return output_model

    def extraction_prompt(self):
-       history = self.cat.stringify_chat_history()
+       history = self.cat.working_memory.stringify_chat_history()

        # JSON structure
        # BaseModel.__fields__['my_field'].type_
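
Taken together, the four call-site changes above follow a single pattern: chat-history helpers are no longer called on the StrayCat (or the form's `self.cat`) object itself, but on its working memory. A minimal before/after sketch, assuming a `stray` StrayCat instance as in the hunks above:

# Old call sites (kept working for now via the StrayCat redirects shown below)
history_text = stray.stringify_chat_history()
history_messages = stray.langchainfy_chat_history()

# New call sites used by the agents and forms in this commit
history_text = stray.working_memory.stringify_chat_history()
history_messages = stray.working_memory.langchainfy_chat_history()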

core/cat/looking_glass/stray_cat.py

Lines changed: 7 additions & 44 deletions

@@ -551,50 +551,13 @@ def classify(
        # set 0.5 as threshold - let's see if it works properly
        return best_label if score < 0.5 else None

-    def stringify_chat_history(self, latest_n: int = 5) -> str:
-        """Serialize chat history.
-        Converts to text the recent conversation turns.
-
-        Parameters
-        ----------
-        latest_n : int
-            Hoe many latest turns to stringify.
-
-        Returns
-        -------
-        history : str
-            String with recent conversation turns.
-
-        Notes
-        -----
-        Such context is placed in the `agent_prompt_suffix` in the place held by {chat_history}.
-
-        The chat history is a dictionary with keys::
-            'who': the name of who said the utterance;
-            'message': the utterance.
-
-        """
-
-        history = self.working_memory.history[-latest_n:]
-
-        history_string = ""
-        for turn in history:
-            history_string += f"\n - {turn.who}: {turn.text}"
-
-        return history_string
-
-    def langchainfy_chat_history(self, latest_n: int = 5) -> List[BaseMessage]:
-
-        chat_history = self.working_memory.history[-latest_n:]
-        recent_history = chat_history[-latest_n:]
-        langchain_chat_history = []
-
-        for message in recent_history:
-            langchain_chat_history.append(
-                message.langchainfy()
-            )
-
-        return langchain_chat_history
+    def langchainfy_chat_history(self, latest_n: int = 10) -> List[BaseMessage]:
+        """Redirects to WorkingMemory.langchainfy_chat_history. Will be removed from this class in v2."""
+        return self.working_memory.langchainfy_chat_history(latest_n)
+
+    def stringify_chat_history(self, latest_n: int = 10) -> str:
+        """Redirects to WorkingMemory.stringify_chat_history. Will be removed from this class in v2."""
+        return self.working_memory.stringify_chat_history(latest_n)

    async def close_connection(self):
        if self.__ws:
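
The two wrapper methods kept on StrayCat make this a backwards-compatible move for plugins until v2, but note that the default `latest_n` changes from 5 to 10 along the way. A small illustrative sketch, assuming a `stray` instance whose working memory already holds some history:

# The wrappers simply forward, so both spellings return the same value
assert stray.stringify_chat_history() == stray.working_memory.stringify_chat_history()

# Callers that relied on the old default of 5 turns can pass it explicitly
last_five_turns = stray.working_memory.stringify_chat_history(latest_n=5)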

core/cat/memory/working_memory.py

Lines changed: 49 additions & 0 deletions

@@ -1,4 +1,5 @@
from typing import List, Optional
+from langchain_core.messages import BaseMessage

from cat.convo.messages import Role, ConversationMessage, UserMessage, CatMessage
from cat.convo.model_interactions import ModelInteraction

@@ -95,3 +96,51 @@ def update_history(self, message: ConversationMessage):
        """
        self.history.append(message)

+
+    def stringify_chat_history(self, latest_n: int = 10) -> str:
+        """Serialize chat history.
+        Converts to text the recent conversation turns.
+        Useful for retrocompatibility with old non-chat models, and to easily insert convo into a prompt without using dedicated objects and libraries.
+
+        Parameters
+        ----------
+        latest_n : int
+            How many latest turns to stringify.
+
+        Returns
+        -------
+        history : str
+            String with recent conversation turns.
+        """
+
+        history = self.history[-latest_n:]
+
+        history_string = ""
+        for turn in history:
+            history_string += f"\n - {turn.who}: {turn.text}"
+
+        return history_string
+
+    def langchainfy_chat_history(self, latest_n: int = 10) -> List[BaseMessage]:
+        """Convert chat history in working memory to langchain objects.
+
+        Parameters
+        ----------
+        latest_n : int
+            How many latest turns to convert.
+
+        Returns
+        -------
+        history : List[BaseMessage]
+            List of langchain HumanMessage / AIMessage.
+        """
+        chat_history = self.history[-latest_n:]
+        recent_history = chat_history[-latest_n:]
+        langchain_chat_history = []
+
+        for message in recent_history:
+            langchain_chat_history.append(
+                message.langchainfy()
+            )
+
+        return langchain_chat_history
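
For reference, a short usage sketch of the two methods now living on WorkingMemory, mirroring the conversation built in the new test file below:

from cat.convo.messages import UserMessage, CatMessage
from cat.memory.working_memory import WorkingMemory

wm = WorkingMemory()
wm.update_history(UserMessage(user_id="123", who="Human", text="Hi"))
wm.update_history(CatMessage(user_id="123", who="AI", text="Meow"))

# Plain-text rendering, e.g. for the {chat_history} slot in a prompt
assert wm.stringify_chat_history() == "\n - Human: Hi\n - AI: Meow"

# Langchain HumanMessage / AIMessage objects for chat models
messages = wm.langchainfy_chat_history(latest_n=10)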
New test file

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+from langchain_core.messages import AIMessage, HumanMessage
+
+from cat.convo.messages import Role, ConversationMessage, UserMessage, CatMessage
+from cat.memory.working_memory import WorkingMemory
+
+def create_working_memory_with_convo_history():
+    """Utility to create a working memory and populate its convo history."""
+
+    working_memory = WorkingMemory()
+    human_message = UserMessage(user_id="123", who="Human", text="Hi")
+    working_memory.update_history(human_message)
+    cat_message = CatMessage(user_id="123", who="AI", text="Meow")
+    working_memory.update_history(cat_message)
+    return working_memory
+
+def test_create_working_memory():
+
+    wm = WorkingMemory()
+    assert wm.history == []
+    assert wm.user_message_json == None
+    assert wm.active_form == None
+    assert wm.recall_query == ""
+    assert wm.episodic_memories == []
+    assert wm.declarative_memories == []
+    assert wm.procedural_memories == []
+    assert wm.model_interactions == []
+
+
+def test_update_history():
+
+    wm = create_working_memory_with_convo_history()
+
+    assert len(wm.history) == 2
+
+    assert isinstance(wm.history[0], UserMessage)
+    assert wm.history[0].who == "Human"
+    assert wm.history[0].role == Role.Human
+    assert wm.history[0].text == "Hi"
+
+    assert isinstance(wm.history[1], CatMessage)
+    assert wm.history[1].who == "AI"
+    assert wm.history[1].role == Role.AI
+    assert wm.history[1].text == "Meow"
+
+
+def test_stringify_chat_history():
+
+    wm = create_working_memory_with_convo_history()
+    assert wm.stringify_chat_history() == "\n - Human: Hi\n - AI: Meow"
+
+
+def test_langchainfy_chat_history():
+
+    wm = create_working_memory_with_convo_history()
+    langchain_convo = wm.langchainfy_chat_history()
+
+    assert len(langchain_convo) == len(wm.history)
+
+    assert isinstance(langchain_convo[0], HumanMessage)
+    assert langchain_convo[0].name == "Human"
+    assert isinstance(langchain_convo[0].content, list)
+    assert langchain_convo[0].content[0] == {"type": "text", "text": "Hi"}
+
+    assert isinstance(langchain_convo[1], AIMessage)
+    assert langchain_convo[1].name == "AI"
+    assert langchain_convo[1].content == "Meow"
+
+# TODOV2: add tests for multimodal messages!
