Skip to content

Commit 60f4348

Browse files
committed
test: _reduce_messages in LangChainAgent
1 parent 89ec76c commit 60f4348

File tree

2 files changed

+75
-24
lines changed

2 files changed

+75
-24
lines changed

src/rai_core/rai/agents/langchain/agent.py

+34-24
Original file line numberDiff line numberDiff line change
@@ -33,25 +33,28 @@ class BaseState(TypedDict):
3333
messages: List[BaseMessage]
3434

3535

36+
newMessageBehaviorType = Literal[
37+
"take_all",
38+
"keep_last",
39+
"queue",
40+
"interuppt_take_all",
41+
"interuppt_keep_last",
42+
]
43+
44+
3645
class LangChainAgent(BaseAgent):
3746
def __init__(
3847
self,
3948
target_connectors: Dict[str, HRIConnector[HRIMessage]],
4049
runnable: Runnable,
4150
state: BaseState | None = None,
42-
new_message_behavior: Literal[
43-
"take_all",
44-
"keep_last",
45-
"queue",
46-
"interuppt_take_all",
47-
"interuppt_keep_last",
48-
] = "interuppt_keep_last",
51+
new_message_behavior: newMessageBehaviorType = "interuppt_keep_last",
4952
max_size: int = 100,
5053
):
5154
super().__init__()
5255
self.logger = logging.getLogger(__name__)
5356
self.agent = runnable
54-
self.new_message_behavior = new_message_behavior
57+
self.new_message_behavior: newMessageBehaviorType = new_message_behavior
5558
self.tracing_callbacks = get_tracing_callbacks()
5659
self.state = state or ReActAgentState(messages=[])
5760
self._langchain_callback = HRICallbackHandler(
@@ -141,26 +144,33 @@ def stop(self):
141144
self.thread = None
142145
self.logger.info("Agent stopped")
143146

144-
def _reduce_messages(self) -> HRIMessage:
145-
text = ""
146-
images = []
147-
audios = []
148-
source_messages = list()
149-
if "take_all" in self.new_message_behavior:
147+
@staticmethod
148+
def _apply_reduction_behavior(
149+
method: newMessageBehaviorType, buffer: Deque
150+
) -> List:
151+
output = list()
152+
if "take_all" in method:
150153
# Take all starting from the oldest
151-
while len(self._received_messages) > 0:
152-
source_messages.append(self._received_messages.popleft())
153-
elif "keep_last" in self.new_message_behavior:
154+
while len(buffer) > 0:
155+
output.append(buffer.popleft())
156+
elif "keep_last" in method:
154157
# Take the recently added message
155-
source_messages.append(self._received_messages.pop())
156-
self._received_messages.clear()
157-
elif self.new_message_behavior == "queue":
158+
output.append(buffer.pop())
159+
buffer.clear()
160+
elif method == "queue":
158161
# Take the first message from the queue. Let other messages wait.
159-
source_messages.append(self._received_messages.popleft())
162+
output.append(buffer.popleft())
160163
else:
161-
raise ValueError(
162-
f"Invalid new_message_behavior: {self.new_message_behavior}"
163-
)
164+
raise ValueError(f"Invalid new_message_behavior: {method}")
165+
return output
166+
167+
def _reduce_messages(self) -> HRIMessage:
168+
text = ""
169+
images = []
170+
audios = []
171+
source_messages = self._apply_reduction_behavior(
172+
self.new_message_behavior, self._received_messages
173+
)
164174
for source_message in source_messages:
165175
text += f"{source_message.text}\n"
166176
images.extend(source_message.images)

tests/agents/test_langchain_agent.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from collections import deque
16+
from typing import List
17+
18+
import pytest
19+
from rai.agents.langchain.agent import LangChainAgent, newMessageBehaviorType
20+
21+
22+
@pytest.mark.parametrize(
23+
"new_message_behavior,in_buffer,out_buffer,output",
24+
[
25+
("take_all", [1, 2, 3], [], [1, 2, 3]),
26+
("keep_last", [1, 2, 3], [], [3]),
27+
("queue", [1, 2, 3], [2, 3], [1]),
28+
("interuppt_take_all", [1, 2, 3], [], [1, 2, 3]),
29+
("interuppt_keep_last", [1, 2, 3], [], [3]),
30+
],
31+
)
32+
def test_reduce_messages(
33+
new_message_behavior: newMessageBehaviorType,
34+
in_buffer: List,
35+
out_buffer: List,
36+
output: List,
37+
):
38+
buffer = deque(in_buffer)
39+
output = LangChainAgent._apply_reduction_behavior(new_message_behavior, buffer)
40+
assert output == output
41+
assert buffer == deque(out_buffer)

0 commit comments

Comments
 (0)