|
| 1 | +# Copyright (C) 2025 Robotec.AI |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import logging |
| 16 | +import threading |
| 17 | +import time |
| 18 | +from collections import deque |
| 19 | +from concurrent.futures import ThreadPoolExecutor |
| 20 | +from typing import Deque, Dict, List, Literal, Optional, TypedDict |
| 21 | + |
| 22 | +from langchain_core.messages import BaseMessage |
| 23 | +from langchain_core.runnables import Runnable |
| 24 | + |
| 25 | +from rai.agents.base import BaseAgent |
| 26 | +from rai.agents.langchain import HRICallbackHandler |
| 27 | +from rai.agents.langchain.runnables import ReActAgentState |
| 28 | +from rai.communication.hri_connector import HRIConnector, HRIMessage |
| 29 | +from rai.initialization import get_tracing_callbacks |
| 30 | + |
| 31 | + |
| 32 | +class BaseState(TypedDict): |
| 33 | + messages: List[BaseMessage] |
| 34 | + |
| 35 | + |
| 36 | +newMessageBehaviorType = Literal[ |
| 37 | + "take_all", |
| 38 | + "keep_last", |
| 39 | + "queue", |
| 40 | + "interrupt_take_all", |
| 41 | + "interrupt_keep_last", |
| 42 | +] |
| 43 | + |
| 44 | + |
| 45 | +class LangChainAgent(BaseAgent): |
| 46 | + """ |
| 47 | + Agent pareametrized by LangGraph runnable that communicates with environment using |
| 48 | + `HRIConnector`. |
| 49 | +
|
| 50 | + Parameters |
| 51 | + ---------- |
| 52 | + target_connectors : Dict[str, HRIConnector[HRIMessage]] |
| 53 | + Dict of target_name: connector. Agent will send it's output to these targets using connectors. |
| 54 | + runnable : Runnable |
| 55 | + LangChain runnable that will be used to generate output. |
| 56 | + state : BaseState | None, optional |
| 57 | + State to seed the LangChain runnable. If None - empty state is used. |
| 58 | + new_message_behavior : newMessageBehaviorType, optional |
| 59 | + Describes how to handle new messages and interact with LangChain runnable. There are 2 main options: |
| 60 | + 1. Agent waits for LangChain runnable to finish processing: |
| 61 | + - "take_all": all messages from the queue are concatenated and processed. |
| 62 | + - "keep_last": only the last received message is processed, others are dropped. |
| 63 | + - "queue": only the first message from the queue is processed, others are kept in the queue. |
| 64 | + 2. Agent interrupts LangChain runnable: |
| 65 | + - "interrupt_take_all": same as "take_all" |
| 66 | + - "interrupt_keep_last": same as "keep_last" |
| 67 | + max_size : int, optional |
| 68 | + Maximum number of messages to keep in the agent's queue. If exceeded, oldest messages are dropped. |
| 69 | +
|
| 70 | +
|
| 71 | + Agent can be started using `run` method. Then it is triggered by `HRIMessage`s submited |
| 72 | + by `__call__` method. They can be submitted in 2 ways: |
| 73 | + - manually using `__call__` method. |
| 74 | + - by subscribing to specific source using HRIConnector with `subscribe_source` method. |
| 75 | +
|
| 76 | + Agent can be stopped using `stop` method. |
| 77 | +
|
| 78 | + Due to asynchronous processing of the Agent, it is adviced to handle it's lifetime |
| 79 | + with :py:class:`rai.agents.AgentRunner` when source is subscribed. |
| 80 | +
|
| 81 | + Examples: |
| 82 | + ```python |
| 83 | + # ROS2 Example - agent triggered manually |
| 84 | + from rai.agents import AgentRunner |
| 85 | + hri_connector = ROS2HRIConnector() |
| 86 | + runnable = create_langgraph() |
| 87 | + agent = LangChainAgent( |
| 88 | + target_connectors={"/to_human": hri_connector}, |
| 89 | + runnable=runnable, |
| 90 | + ) |
| 91 | + agent.run() |
| 92 | + agent(HRIMessage(text="Hello!")) |
| 93 | + agent.wait() |
| 94 | + agent.stop() |
| 95 | +
|
| 96 | + # ROS2 Example - triggered by messages on ros2 topic |
| 97 | + ... |
| 98 | + runner = AgentRunner([agent]) |
| 99 | + runner.run() |
| 100 | + agent.source_callback("/from_human", hri_connector) |
| 101 | + runner.wait_for_shutdown() |
| 102 | +
|
| 103 | + # Agent will act messages published to rai_interfaces.msg.HRIMessage sent to /from_human topic |
| 104 | + """ |
| 105 | + |
| 106 | + def __init__( |
| 107 | + self, |
| 108 | + target_connectors: Dict[str, HRIConnector[HRIMessage]], |
| 109 | + runnable: Runnable, |
| 110 | + state: BaseState | None = None, |
| 111 | + new_message_behavior: newMessageBehaviorType = "interrupt_keep_last", |
| 112 | + max_size: int = 100, |
| 113 | + ): |
| 114 | + super().__init__() |
| 115 | + self.logger = logging.getLogger(__name__) |
| 116 | + self.agent = runnable |
| 117 | + self.new_message_behavior: newMessageBehaviorType = new_message_behavior |
| 118 | + self.tracing_callbacks = get_tracing_callbacks() |
| 119 | + self.state = state or ReActAgentState(messages=[]) |
| 120 | + self._langchain_callback = HRICallbackHandler( |
| 121 | + connectors=target_connectors, |
| 122 | + aggregate_chunks=True, |
| 123 | + logger=self.logger, |
| 124 | + ) |
| 125 | + |
| 126 | + self._received_messages: Deque[HRIMessage] = deque() |
| 127 | + self._buffer_lock = threading.Lock() |
| 128 | + self.max_size = max_size |
| 129 | + |
| 130 | + self._thread: Optional[threading.Thread] = None |
| 131 | + self._stop_event = threading.Event() |
| 132 | + self._executor = ThreadPoolExecutor(max_workers=1) |
| 133 | + self._interrupt_event = threading.Event() |
| 134 | + self._agent_ready_event = threading.Event() |
| 135 | + |
| 136 | + def subscribe_source(self, source: str, connector: HRIConnector[HRIMessage]): |
| 137 | + connector.register_callback( |
| 138 | + source, |
| 139 | + self.__call__, |
| 140 | + ) |
| 141 | + |
| 142 | + def __call__(self, msg: HRIMessage): |
| 143 | + with self._buffer_lock: |
| 144 | + if ( |
| 145 | + self.max_size is not None |
| 146 | + and len(self._received_messages) >= self.max_size |
| 147 | + ): |
| 148 | + self.logger.warning("Buffer overflow. Dropping olders message") |
| 149 | + self._received_messages.popleft() |
| 150 | + if "interrupt" in self.new_message_behavior: |
| 151 | + self._executor.submit(self._interrupt_agent_and_run) |
| 152 | + self.logger.info(f"Received message: {msg}, {type(msg)}") |
| 153 | + self._received_messages.append(msg) |
| 154 | + |
| 155 | + def run(self): |
| 156 | + if self._thread is not None: |
| 157 | + raise RuntimeError("Agent is already running") |
| 158 | + self._thread = threading.Thread(target=self._run_loop) |
| 159 | + self._thread.start() |
| 160 | + self._agent_ready_event.set() |
| 161 | + self.logger.info("Agent started") |
| 162 | + |
| 163 | + def ready(self): |
| 164 | + return self._agent_ready_event.is_set() and len(self._received_messages) == 0 |
| 165 | + |
| 166 | + def wait(self): |
| 167 | + while len(self._received_messages) > 0: |
| 168 | + time.sleep(0.1) |
| 169 | + |
| 170 | + return self._agent_ready_event.wait() |
| 171 | + |
| 172 | + def _interrupt_agent_and_run(self): |
| 173 | + if self.ready(): |
| 174 | + self.logger.info("Agent is ready. No need to interrupt it.") |
| 175 | + return |
| 176 | + self.logger.info("Interrupting agent...") |
| 177 | + self._interrupt_event.set() |
| 178 | + self._agent_ready_event.wait() |
| 179 | + self._interrupt_event.clear() |
| 180 | + self.logger.info("Interrupting agent: DONE") |
| 181 | + |
| 182 | + def _run_agent(self): |
| 183 | + if len(self._received_messages) == 0: |
| 184 | + self._agent_ready_event.set() |
| 185 | + self.logger.info("Waiting for messages...") |
| 186 | + time.sleep(0.5) |
| 187 | + return |
| 188 | + self._agent_ready_event.clear() |
| 189 | + try: |
| 190 | + self.logger.info("Running agent...") |
| 191 | + reduced_message = self._reduce_messages() |
| 192 | + langchain_message = reduced_message.to_langchain() |
| 193 | + self.state["messages"].append(langchain_message) |
| 194 | + for _ in self.agent.stream( |
| 195 | + self.state, |
| 196 | + config={ |
| 197 | + "callbacks": [self._langchain_callback, *self.tracing_callbacks] |
| 198 | + }, |
| 199 | + ): |
| 200 | + if self._interrupt_event.is_set(): |
| 201 | + break |
| 202 | + finally: |
| 203 | + self._agent_ready_event.set() |
| 204 | + |
| 205 | + def _run_loop(self): |
| 206 | + while not self._stop_event.is_set(): |
| 207 | + if self._agent_ready_event.wait(0.01): |
| 208 | + self._run_agent() |
| 209 | + |
| 210 | + def stop(self): |
| 211 | + self._stop_event.set() |
| 212 | + self._interrupt_event.set() |
| 213 | + self._agent_ready_event.wait() |
| 214 | + if self._thread is not None: |
| 215 | + self.logger.info("Stopping the agent. Please wait...") |
| 216 | + self._thread.join() |
| 217 | + self._thread = None |
| 218 | + self.logger.info("Agent stopped") |
| 219 | + |
| 220 | + @staticmethod |
| 221 | + def _apply_reduction_behavior( |
| 222 | + method: newMessageBehaviorType, buffer: Deque[HRIMessage] |
| 223 | + ) -> List[HRIMessage]: |
| 224 | + output = list() |
| 225 | + if "take_all" in method: |
| 226 | + # Take all starting from the oldest |
| 227 | + while len(buffer) > 0: |
| 228 | + output.append(buffer.popleft()) |
| 229 | + elif "keep_last" in method: |
| 230 | + # Take the recently added message |
| 231 | + output.append(buffer.pop()) |
| 232 | + buffer.clear() |
| 233 | + elif method == "queue": |
| 234 | + # Take the first message from the queue. Let other messages wait. |
| 235 | + output.append(buffer.popleft()) |
| 236 | + else: |
| 237 | + raise ValueError(f"Invalid new_message_behavior: {method}") |
| 238 | + return output |
| 239 | + |
| 240 | + def _reduce_messages(self) -> HRIMessage: |
| 241 | + text = "" |
| 242 | + images = [] |
| 243 | + audios = [] |
| 244 | + with self._buffer_lock: |
| 245 | + source_messages = self._apply_reduction_behavior( |
| 246 | + self.new_message_behavior, self._received_messages |
| 247 | + ) |
| 248 | + for source_message in source_messages: |
| 249 | + text += f"{source_message.text}\n" |
| 250 | + images.extend(source_message.images) |
| 251 | + audios.extend(source_message.audios) |
| 252 | + return HRIMessage( |
| 253 | + text=text, |
| 254 | + images=images, |
| 255 | + audios=audios, |
| 256 | + message_author="human", |
| 257 | + ) |
0 commit comments