Skip to content

Commit c7b011b

Browse files
authored
refactor+fix: add LangChainAgent abstraction and fix HRI communication (#538)
1 parent c2251f0 commit c7b011b

File tree

13 files changed

+452
-238
lines changed

13 files changed

+452
-238
lines changed

examples/agents/react.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,29 @@
1212
# See the License for the specific language goveself.rning permissions and
1313
# limitations under the License.
1414

15-
from rai.agents import AgentRunner, ReActAgent
16-
from rai.communication.ros2 import ROS2Connector, ROS2Context, ROS2HRIConnector
15+
16+
from rai.agents.langchain.react_agent import ReActAgent
17+
from rai.communication.hri_connector import HRIMessage
18+
from rai.communication.ros2 import ROS2Connector, ROS2Context
19+
from rai.communication.ros2.connectors.hri_connector import ROS2HRIConnector
1720
from rai.tools.ros2 import ROS2Toolkit
1821

1922

2023
@ROS2Context()
2124
def main():
22-
connector = ROS2HRIConnector(sources=["/from_human"], targets=["/to_human"])
2325
ros2_connector = ROS2Connector()
26+
hri_connector = ROS2HRIConnector()
27+
2428
agent = ReActAgent(
25-
connectors={"hri": connector},
29+
target_connectors={
30+
"/to_human": hri_connector,
31+
}, # agnet's output is sent to /to_human ros2 topic
2632
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
27-
) # type: ignore
28-
runner = AgentRunner([agent])
29-
runner.run_and_wait_for_shutdown()
33+
)
34+
agent.run()
35+
agent(HRIMessage(text="What do you see?"))
36+
agent.wait() # wait for agent to finish
37+
agent.stop()
3038

3139

3240
if __name__ == "__main__":

examples/agents/react_ros2.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language goveself.rning permissions and
13+
# limitations under the License.
14+
15+
16+
from rai.agents import AgentRunner
17+
from rai.agents.langchain.react_agent import ReActAgent
18+
from rai.communication.ros2 import ROS2Connector, ROS2Context
19+
from rai.communication.ros2.connectors.hri_connector import ROS2HRIConnector
20+
from rai.tools.ros2 import ROS2Toolkit
21+
22+
23+
@ROS2Context()
24+
def main():
25+
ros2_connector = ROS2Connector()
26+
hri_connector = ROS2HRIConnector()
27+
28+
agent = ReActAgent(
29+
target_connectors={
30+
"/to_human": hri_connector,
31+
},
32+
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
33+
)
34+
# Agent will wait for messages published to /from_human ros2 topic
35+
agent.subscribe_source("/from_human", hri_connector)
36+
runner = AgentRunner([agent])
37+
runner.run()
38+
39+
40+
if __name__ == "__main__":
41+
main()

src/rai_core/rai/agents/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from rai.agents.base import BaseAgent
1616
from rai.agents.conversational_agent import create_conversational_agent
17-
from rai.agents.react_agent import ReActAgent
17+
from rai.agents.langchain.react_agent import ReActAgent
1818
from rai.agents.runner import AgentRunner, wait_for_shutdown
1919
from rai.agents.state_based import create_state_based_agent
2020
from rai.agents.tool_runner import ToolRunner
+257
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import threading
17+
import time
18+
from collections import deque
19+
from concurrent.futures import ThreadPoolExecutor
20+
from typing import Deque, Dict, List, Literal, Optional, TypedDict
21+
22+
from langchain_core.messages import BaseMessage
23+
from langchain_core.runnables import Runnable
24+
25+
from rai.agents.base import BaseAgent
26+
from rai.agents.langchain import HRICallbackHandler
27+
from rai.agents.langchain.runnables import ReActAgentState
28+
from rai.communication.hri_connector import HRIConnector, HRIMessage
29+
from rai.initialization import get_tracing_callbacks
30+
31+
32+
class BaseState(TypedDict):
33+
messages: List[BaseMessage]
34+
35+
36+
newMessageBehaviorType = Literal[
37+
"take_all",
38+
"keep_last",
39+
"queue",
40+
"interrupt_take_all",
41+
"interrupt_keep_last",
42+
]
43+
44+
45+
class LangChainAgent(BaseAgent):
46+
"""
47+
Agent pareametrized by LangGraph runnable that communicates with environment using
48+
`HRIConnector`.
49+
50+
Parameters
51+
----------
52+
target_connectors : Dict[str, HRIConnector[HRIMessage]]
53+
Dict of target_name: connector. Agent will send it's output to these targets using connectors.
54+
runnable : Runnable
55+
LangChain runnable that will be used to generate output.
56+
state : BaseState | None, optional
57+
State to seed the LangChain runnable. If None - empty state is used.
58+
new_message_behavior : newMessageBehaviorType, optional
59+
Describes how to handle new messages and interact with LangChain runnable. There are 2 main options:
60+
1. Agent waits for LangChain runnable to finish processing:
61+
- "take_all": all messages from the queue are concatenated and processed.
62+
- "keep_last": only the last received message is processed, others are dropped.
63+
- "queue": only the first message from the queue is processed, others are kept in the queue.
64+
2. Agent interrupts LangChain runnable:
65+
- "interrupt_take_all": same as "take_all"
66+
- "interrupt_keep_last": same as "keep_last"
67+
max_size : int, optional
68+
Maximum number of messages to keep in the agent's queue. If exceeded, oldest messages are dropped.
69+
70+
71+
Agent can be started using `run` method. Then it is triggered by `HRIMessage`s submited
72+
by `__call__` method. They can be submitted in 2 ways:
73+
- manually using `__call__` method.
74+
- by subscribing to specific source using HRIConnector with `subscribe_source` method.
75+
76+
Agent can be stopped using `stop` method.
77+
78+
Due to asynchronous processing of the Agent, it is adviced to handle it's lifetime
79+
with :py:class:`rai.agents.AgentRunner` when source is subscribed.
80+
81+
Examples:
82+
```python
83+
# ROS2 Example - agent triggered manually
84+
from rai.agents import AgentRunner
85+
hri_connector = ROS2HRIConnector()
86+
runnable = create_langgraph()
87+
agent = LangChainAgent(
88+
target_connectors={"/to_human": hri_connector},
89+
runnable=runnable,
90+
)
91+
agent.run()
92+
agent(HRIMessage(text="Hello!"))
93+
agent.wait()
94+
agent.stop()
95+
96+
# ROS2 Example - triggered by messages on ros2 topic
97+
...
98+
runner = AgentRunner([agent])
99+
runner.run()
100+
agent.source_callback("/from_human", hri_connector)
101+
runner.wait_for_shutdown()
102+
103+
# Agent will act messages published to rai_interfaces.msg.HRIMessage sent to /from_human topic
104+
"""
105+
106+
def __init__(
107+
self,
108+
target_connectors: Dict[str, HRIConnector[HRIMessage]],
109+
runnable: Runnable,
110+
state: BaseState | None = None,
111+
new_message_behavior: newMessageBehaviorType = "interrupt_keep_last",
112+
max_size: int = 100,
113+
):
114+
super().__init__()
115+
self.logger = logging.getLogger(__name__)
116+
self.agent = runnable
117+
self.new_message_behavior: newMessageBehaviorType = new_message_behavior
118+
self.tracing_callbacks = get_tracing_callbacks()
119+
self.state = state or ReActAgentState(messages=[])
120+
self._langchain_callback = HRICallbackHandler(
121+
connectors=target_connectors,
122+
aggregate_chunks=True,
123+
logger=self.logger,
124+
)
125+
126+
self._received_messages: Deque[HRIMessage] = deque()
127+
self._buffer_lock = threading.Lock()
128+
self.max_size = max_size
129+
130+
self._thread: Optional[threading.Thread] = None
131+
self._stop_event = threading.Event()
132+
self._executor = ThreadPoolExecutor(max_workers=1)
133+
self._interrupt_event = threading.Event()
134+
self._agent_ready_event = threading.Event()
135+
136+
def subscribe_source(self, source: str, connector: HRIConnector[HRIMessage]):
137+
connector.register_callback(
138+
source,
139+
self.__call__,
140+
)
141+
142+
def __call__(self, msg: HRIMessage):
143+
with self._buffer_lock:
144+
if (
145+
self.max_size is not None
146+
and len(self._received_messages) >= self.max_size
147+
):
148+
self.logger.warning("Buffer overflow. Dropping olders message")
149+
self._received_messages.popleft()
150+
if "interrupt" in self.new_message_behavior:
151+
self._executor.submit(self._interrupt_agent_and_run)
152+
self.logger.info(f"Received message: {msg}, {type(msg)}")
153+
self._received_messages.append(msg)
154+
155+
def run(self):
156+
if self._thread is not None:
157+
raise RuntimeError("Agent is already running")
158+
self._thread = threading.Thread(target=self._run_loop)
159+
self._thread.start()
160+
self._agent_ready_event.set()
161+
self.logger.info("Agent started")
162+
163+
def ready(self):
164+
return self._agent_ready_event.is_set() and len(self._received_messages) == 0
165+
166+
def wait(self):
167+
while len(self._received_messages) > 0:
168+
time.sleep(0.1)
169+
170+
return self._agent_ready_event.wait()
171+
172+
def _interrupt_agent_and_run(self):
173+
if self.ready():
174+
self.logger.info("Agent is ready. No need to interrupt it.")
175+
return
176+
self.logger.info("Interrupting agent...")
177+
self._interrupt_event.set()
178+
self._agent_ready_event.wait()
179+
self._interrupt_event.clear()
180+
self.logger.info("Interrupting agent: DONE")
181+
182+
def _run_agent(self):
183+
if len(self._received_messages) == 0:
184+
self._agent_ready_event.set()
185+
self.logger.info("Waiting for messages...")
186+
time.sleep(0.5)
187+
return
188+
self._agent_ready_event.clear()
189+
try:
190+
self.logger.info("Running agent...")
191+
reduced_message = self._reduce_messages()
192+
langchain_message = reduced_message.to_langchain()
193+
self.state["messages"].append(langchain_message)
194+
for _ in self.agent.stream(
195+
self.state,
196+
config={
197+
"callbacks": [self._langchain_callback, *self.tracing_callbacks]
198+
},
199+
):
200+
if self._interrupt_event.is_set():
201+
break
202+
finally:
203+
self._agent_ready_event.set()
204+
205+
def _run_loop(self):
206+
while not self._stop_event.is_set():
207+
if self._agent_ready_event.wait(0.01):
208+
self._run_agent()
209+
210+
def stop(self):
211+
self._stop_event.set()
212+
self._interrupt_event.set()
213+
self._agent_ready_event.wait()
214+
if self._thread is not None:
215+
self.logger.info("Stopping the agent. Please wait...")
216+
self._thread.join()
217+
self._thread = None
218+
self.logger.info("Agent stopped")
219+
220+
@staticmethod
221+
def _apply_reduction_behavior(
222+
method: newMessageBehaviorType, buffer: Deque[HRIMessage]
223+
) -> List[HRIMessage]:
224+
output = list()
225+
if "take_all" in method:
226+
# Take all starting from the oldest
227+
while len(buffer) > 0:
228+
output.append(buffer.popleft())
229+
elif "keep_last" in method:
230+
# Take the recently added message
231+
output.append(buffer.pop())
232+
buffer.clear()
233+
elif method == "queue":
234+
# Take the first message from the queue. Let other messages wait.
235+
output.append(buffer.popleft())
236+
else:
237+
raise ValueError(f"Invalid new_message_behavior: {method}")
238+
return output
239+
240+
def _reduce_messages(self) -> HRIMessage:
241+
text = ""
242+
images = []
243+
audios = []
244+
with self._buffer_lock:
245+
source_messages = self._apply_reduction_behavior(
246+
self.new_message_behavior, self._received_messages
247+
)
248+
for source_message in source_messages:
249+
text += f"{source_message.text}\n"
250+
images.extend(source_message.images)
251+
audios.extend(source_message.audios)
252+
return HRIMessage(
253+
text=text,
254+
images=images,
255+
audios=audios,
256+
message_author="human",
257+
)

src/rai_core/rai/agents/langchain/callback.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import logging
1616
import threading
17-
from typing import List, Optional
17+
from typing import Dict, List, Optional
1818
from uuid import UUID
1919

2020
from langchain_core.callbacks import BaseCallbackHandler
@@ -27,7 +27,7 @@
2727
class HRICallbackHandler(BaseCallbackHandler):
2828
def __init__(
2929
self,
30-
connectors: dict[str, HRIConnector[HRIMessage]],
30+
connectors: Dict[str, HRIConnector[HRIMessage]],
3131
aggregate_chunks: bool = False,
3232
splitting_chars: Optional[List[str]] = None,
3333
max_buffer_size: int = 200,
@@ -47,21 +47,20 @@ def _should_split(self, token: str) -> bool:
4747
return token in self.splitting_chars
4848

4949
def _send_all_targets(self, tokens: str, done: bool = False):
50-
self.logger.info(
51-
f"Sending {len(tokens)} tokens to {len(self.connectors)} connectors"
52-
)
53-
for connector_name, connector in self.connectors.items():
50+
for target, connector in self.connectors.items():
51+
self.logger.info(f"Sending {len(tokens)} tokens to target: {target}")
5452
try:
55-
connector.send_all_targets(
53+
to_send: HRIMessage = connector.build_message(
5654
AIMessage(content=tokens),
5755
self.current_conversation_id,
5856
self.current_chunk_id,
5957
done,
6058
)
61-
self.logger.debug(f"Sent {len(tokens)} tokens to {connector_name}")
59+
connector.send_message(to_send, target)
60+
self.logger.debug(f"Sent {len(tokens)} tokens to hri_connector.")
6261
except Exception as e:
6362
self.logger.error(
64-
f"Failed to send {len(tokens)} tokens to {connector_name}: {e}"
63+
f"Failed to send {len(tokens)} tokens to hri_connector: {e}"
6564
)
6665

6766
def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs):

0 commit comments

Comments
 (0)