diff --git a/docs/developer_guide.md b/docs/developer_guide.md
index 1154bd4c4..db2f8f949 100644
--- a/docs/developer_guide.md
+++ b/docs/developer_guide.md
@@ -132,7 +132,7 @@ Choose the conversational_agent when beginning with RAI, and consider upgrading
 from myrobot import robot
 from rai.agents.conversational_agent import create_conversational_agent
-from rai.agents.state_based import create_state_based_agent
+from rai.agents.langchain import create_state_based_runnable
 from rai import get_llm_model
 
 SYSTEM_PROMPT = "You are a robot with interfaces..."
 
@@ -144,7 +144,7 @@ tools = [pick_up_object, scan_object, SayTool(robot=robot)]
 conversational_agent = create_conversational_agent(
     llm=llm, tools=tools, system_prompt=SYSTEM_PROMPT
 )
-state_based_agent = create_state_based_agent(
+state_based_agent = create_state_based_runnable(
     llm=llm, state_retriever=state_retriever, tools=tools
 )
 
diff --git a/examples/agents/state_based.py b/examples/agents/state_based.py
new file mode 100644
index 000000000..abe7273e7
--- /dev/null
+++ b/examples/agents/state_based.py
@@ -0,0 +1,57 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from rai.agents import wait_for_shutdown
+from rai.agents.langchain import StateBasedConfig
+from rai.agents.ros2 import ROS2StateBasedAgent
+from rai.aggregators.ros2 import (
+    ROS2ImgVLMDiffAggregator,
+    ROS2LogsAggregator,
+)
+from rai.communication.ros2 import (
+    ROS2Connector,
+    ROS2Context,
+    ROS2HRIConnector,
+)
+from rai.tools.ros2.generic.toolkit import ROS2Toolkit
+
+
+@ROS2Context()
+def main():
+    hri_connector = ROS2HRIConnector()
+    ros2_connector = ROS2Connector()
+
+    config = StateBasedConfig(
+        aggregators={
+            ("/camera/camera/color/image_raw", "sensor_msgs/msg/Image"): [
+                ROS2ImgVLMDiffAggregator()
+            ],
+            "/rosout": [
+                ROS2LogsAggregator()
+            ],  # if msg_type is not provided, the topic must already exist
+        }
+    )
+
+    agent = ROS2StateBasedAgent(
+        config=config,
+        target_connectors={"to_human": hri_connector},
+        tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
+    )
+    agent.subscribe_source("/from_human", hri_connector)
+    agent.run()
+    wait_for_shutdown([agent])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/rai_core/rai/agents/__init__.py b/src/rai_core/rai/agents/__init__.py
index b28c98661..c21fc978b 100644
--- a/src/rai_core/rai/agents/__init__.py
+++ b/src/rai_core/rai/agents/__init__.py
@@ -16,7 +16,6 @@
 from rai.agents.conversational_agent import create_conversational_agent
 from rai.agents.langchain.react_agent import ReActAgent
 from rai.agents.runner import AgentRunner, wait_for_shutdown
-from rai.agents.state_based import create_state_based_agent
 from rai.agents.tool_runner import ToolRunner
 
 __all__ = [
@@ -25,6 +24,5 @@
     "ReActAgent",
     "ToolRunner",
     "create_conversational_agent",
-    "create_state_based_agent",
     "wait_for_shutdown",
 ]
diff --git a/src/rai_core/rai/agents/langchain/__init__.py b/src/rai_core/rai/agents/langchain/__init__.py
index 40ce9d9f7..e8540294b 100644
--- a/src/rai_core/rai/agents/langchain/__init__.py
+++ b/src/rai_core/rai/agents/langchain/__init__.py
@@ -12,7 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .agent import BaseState, LangChainAgent, newMessageBehaviorType
 from .callback import HRICallbackHandler
-from .runnables import create_react_runnable
+from .react_agent import ReActAgent
+from .runnables import (
+    ReActAgentState,
+    create_react_runnable,
+    create_state_based_runnable,
+)
+from .state_based_agent import BaseStateBasedAgent, StateBasedConfig
 
-__all__ = ["HRICallbackHandler", "create_react_runnable"]
+__all__ = [
+    "BaseState",
+    "BaseStateBasedAgent",
+    "HRICallbackHandler",
+    "LangChainAgent",
+    "ReActAgent",
+    "ReActAgentState",
+    "StateBasedConfig",
+    "create_react_runnable",
+    "create_state_based_runnable",
+    "newMessageBehaviorType",
+]
diff --git a/src/rai_core/rai/agents/langchain/agent.py b/src/rai_core/rai/agents/langchain/agent.py
index 1e4d2632f..a4b15c1c5 100644
--- a/src/rai_core/rai/agents/langchain/agent.py
+++ b/src/rai_core/rai/agents/langchain/agent.py
@@ -22,12 +22,13 @@
 from langchain_core.messages import BaseMessage
 from langchain_core.runnables import Runnable
 
-from rai.agents.base import BaseAgent
-from rai.agents.langchain import HRICallbackHandler
-from rai.agents.langchain.runnables import ReActAgentState
 from rai.communication.hri_connector import HRIConnector, HRIMessage
 from rai.initialization import get_tracing_callbacks
 
+from ..base import BaseAgent
+from .callback import HRICallbackHandler
+from .runnables import ReActAgentState
+
 
 class BaseState(TypedDict):
     messages: List[BaseMessage]
@@ -208,6 +209,7 @@ def _run_loop(self):
         self._run_agent()
 
     def stop(self):
+        """Stop the agent's execution loop."""
         self._stop_event.set()
         self._interrupt_event.set()
         self._agent_ready_event.wait()
@@ -216,6 +218,7 @@ def stop(self):
             self._thread.join()
             self._thread = None
         self.logger.info("Agent stopped")
+        self._stop_event.clear()
 
     @staticmethod
     def _apply_reduction_behavior(
diff --git a/src/rai_core/rai/agents/langchain/react_agent.py b/src/rai_core/rai/agents/langchain/react_agent.py
index e96baea38..bdc84d5b4 100644
--- a/src/rai_core/rai/agents/langchain/react_agent.py
+++ b/src/rai_core/rai/agents/langchain/react_agent.py
@@ -17,11 +17,11 @@
 from langchain_core.language_models import BaseChatModel
 from langchain_core.tools import BaseTool
 
-from rai.agents.langchain import create_react_runnable
-from rai.agents.langchain.agent import LangChainAgent
-from rai.agents.langchain.runnables import ReActAgentState
 from rai.communication.hri_connector import HRIConnector, HRIMessage
 
+from .agent import LangChainAgent
+from .runnables import ReActAgentState, create_react_runnable
+
 
 class ReActAgent(LangChainAgent):
     def __init__(
diff --git a/src/rai_core/rai/agents/langchain/runnables.py b/src/rai_core/rai/agents/langchain/runnables.py
index 5d1d1067d..d5d824ef4 100644
--- a/src/rai_core/rai/agents/langchain/runnables.py
+++ b/src/rai_core/rai/agents/langchain/runnables.py
@@ -12,18 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
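+
+# `create_state_based_runnable` builds a LangGraph in which a "state_retriever"
+# node runs before every LLM call, appending the aggregated state messages
+# returned by the supplied `state_retriever` callable to the conversation on
+# each llm/tools cycle.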
+import logging
 from functools import partial
-from typing import List, Optional, TypedDict, cast
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypedDict,
+    cast,
+)
 
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import BaseMessage, SystemMessage
+from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.graph import START, StateGraph
 from langgraph.prebuilt.tool_node import tools_condition
 
-from rai.agents.tool_runner import ToolRunner
 from rai.initialization import get_llm_model
+from rai.messages import HumanMultimodalMessage
+
+from ..tool_runner import ToolRunner
 
 
 class ReActAgentState(TypedDict):
@@ -112,3 +123,50 @@ def create_react_runnable(
 
     # Compile the graph
     return graph.compile()
+
+
+def retriever_wrapper(
+    state_retriever: Callable[[], Dict[str, HumanMessage | HumanMultimodalMessage]],
+    state: ReActAgentState,
+):
+    """This wrapper puts state messages into the LLM context."""
+    for source, message in state_retriever().items():
+        message.content = f"{source}: {message.content}"
+        logging.getLogger("state_retriever").debug(
+            f"Adding state message:\n{message.pretty_repr()}"
+        )
+        state["messages"].append(message)
+    return state
+
+
+def create_state_based_runnable(
+    llm: Optional[BaseChatModel] = None,
+    tools: Optional[List[BaseTool]] = None,
+    system_prompt: Optional[str] = None,
+    state_retriever: Optional[Callable[[], Dict[str, Any]]] = None,
+) -> Runnable[ReActAgentState, ReActAgentState]:
+    if llm is None:
+        llm = get_llm_model("complex_model", streaming=True)
+    graph = StateGraph(ReActAgentState)
+    graph.add_edge(START, "state_retriever")
+    graph.add_edge("state_retriever", "llm")
+    graph.add_conditional_edges(
+        "llm",
+        tools_condition,
+    )
+    graph.add_edge("tools", "state_retriever")
+
+    if state_retriever is None:
+        state_retriever = lambda: {}
+
+    graph.add_node("state_retriever", partial(retriever_wrapper, state_retriever))
+
+    if tools is None:
+        tools = []
+    bound_llm = cast(BaseChatModel, llm.bind_tools(tools))
+    graph.add_node("llm", partial(llm_node, bound_llm, system_prompt))
+
+    tool_runner = ToolRunner(tools)
+    graph.add_node("tools", tool_runner)
+
+    return graph.compile()
diff --git a/src/rai_core/rai/agents/langchain/state_based_agent.py b/src/rai_core/rai/agents/langchain/state_based_agent.py
new file mode 100644
index 000000000..236e9d758
--- /dev/null
+++ b/src/rai_core/rai/agents/langchain/state_based_agent.py
@@ -0,0 +1,188 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
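+
+# `BaseStateBasedAgent` runs the aggregators from `StateBasedConfig` on a
+# fixed interval in a background thread and caches their newest outputs;
+# `get_state` hands that cache to `create_state_based_runnable`, which injects
+# it into the LLM context before each model call.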
+
+import threading
+import time
+from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Dict, List, Optional, Tuple, Union
+
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.tools import BaseTool
+from pydantic import BaseModel, ConfigDict, Field
+
+from rai.aggregators import BaseAggregator
+from rai.communication.base_connector import BaseConnector
+from rai.communication.hri_connector import HRIConnector, HRIMessage
+from rai.messages.multimodal import HumanMultimodalMessage
+
+from .agent import LangChainAgent, newMessageBehaviorType
+from .runnables import ReActAgentState, create_state_based_runnable
+
+
+class StateBasedConfig(BaseModel):
+    aggregators: Dict[Union[str, Tuple[str, str]], List[BaseAggregator]] = Field(
+        description="Dict of topic : aggregator or (topic, msg_type) : aggregator"
+    )
+    time_interval: float = Field(default=5.0)
+    max_workers: int = 8
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+    )
+
+
+class BaseStateBasedAgent(LangChainAgent, ABC):
+    """
+    Agent that runs the aggregators from ``config.aggregators`` every
+    ``config.time_interval`` seconds. Aggregators are registered to their
+    sources using the connector returned by :py:meth:`setup_connector`, e.g.
+    :py:class:`~rai.communication.ros2.connectors.ROS2Connector`.
+
+    The output of the aggregators is called `state`. The state is saved and
+    can be retrieved with the `get_state` method.
+
+    In `BaseStateBasedAgent`, the state is added to the LLM history. For more
+    details about the LLM agent, see
+    :py:func:`~rai.agents.langchain.runnables.create_state_based_runnable`
+    """
+
+    def __init__(
+        self,
+        config: StateBasedConfig,
+        target_connectors: Dict[str, HRIConnector[HRIMessage]],
+        llm: Optional[BaseChatModel] = None,
+        tools: Optional[List[BaseTool]] = None,
+        state: Optional[ReActAgentState] = None,
+        system_prompt: Optional[str] = None,
+        new_message_behavior: newMessageBehaviorType = "interrupt_keep_last",
+        max_size: int = 100,
+    ):
+        runnable = create_state_based_runnable(
+            llm=llm,
+            tools=tools,
+            system_prompt=system_prompt,
+            state_retriever=self.get_state,
+        )
+        super().__init__(
+            target_connectors=target_connectors,
+            runnable=runnable,
+            state=state,
+            new_message_behavior=new_message_behavior,
+            max_size=max_size,
+        )
+        self.config = config
+
+        self._aggregation_results: Dict[str, HumanMessage | HumanMultimodalMessage] = (
+            dict()
+        )
+        self._aggregation_thread: threading.Thread | None = None
+
+        self._registered_callbacks = set()
+        self._connector = self.setup_connector()
+        self._configure_state_sources()
+
+    @abstractmethod
+    def setup_connector(self) -> BaseConnector:
+        pass
+
+    def _configure_state_sources(self):
+        for source, aggregators in self.config.aggregators.items():
+            if isinstance(source, tuple):
+                source, msg_type = source
+            else:
+                msg_type = None
+            for aggregator in aggregators:
+                callback_id = self._connector.register_callback(
+                    source, aggregator, raw=True, msg_type=msg_type
+                )
+                self._registered_callbacks.add(callback_id)
+
+    def run(self):
+        super().run()
+        self._aggregation_thread = threading.Thread(target=self._run_state_loop)
+        self._aggregation_thread.start()
+
+    def get_state(self) -> Dict[str, HumanMessage | HumanMultimodalMessage]:
+        """Returns the latest output of every aggregator"""
+        return self._aggregation_results
+
+    def _run_state_loop(self):
+        """Periodically runs aggregation on the collected data"""
+        while not self._stop_event.is_set():
+            ts = time.perf_counter()
+            self.logger.debug("Starting aggregation interval")
+            self._on_aggregation_interval()
+            elapsed_time = time.perf_counter() - ts
+            self.logger.debug(f"Aggregation done in: {elapsed_time:.2f}s")
+            if elapsed_time > self.config.time_interval:
+                self.logger.warning(
+                    "State aggregation time interval exceeded. Expected "
+                    f"{self.config.time_interval:.2f}s, got {elapsed_time:.2f}s. Consider "
+                    f"increasing {self.__class__.__name__}.config.time_interval."
+                )
+            time.sleep(max(0, self.config.time_interval - elapsed_time))
+
+    def _on_aggregation_interval(self):
+        """Runs every registered aggregator once, in parallel, on its buffered data"""
+
+        def process_aggregator(
+            source: str, aggregator: BaseAggregator
+        ) -> Tuple[str, BaseMessage | None]:
+            self.logger.info(
+                f"Running aggregator: {aggregator}(source={source}) on {len(aggregator.get_buffer())} messages"
+            )
+            ts = time.perf_counter()
+
+            output = aggregator.get()
+
+            self.logger.debug(
+                f'Aggregator "{aggregator}(source={source})" done in {time.perf_counter() - ts:.2f}s'
+            )
+            return source, output
+
+        with ThreadPoolExecutor(max_workers=self.config.max_workers) as executor:
+            futures = list()
+            for source, aggregators in self.config.aggregators.items():
+                for aggregator in aggregators:
+                    future = executor.submit(process_aggregator, source, aggregator)
+                    futures.append(future)
+
+            for future in as_completed(futures):
+                try:
+                    source, output = future.result()
+                except Exception as e:
+                    self.logger.error(f"Aggregator crashed: {e}")
+                    continue
+
+                if output is None:
+                    continue
+                self._aggregation_results[source] = output
+
+    def stop(self):
+        """Stop the agent's execution loop."""
+        self._stop_event.set()
+        self._interrupt_event.set()
+        self._agent_ready_event.wait()
+        if self._thread is not None:
+            self.logger.info("Stopping the agent. Please wait...")
+            self._thread.join()
+            self._thread = None
+        if self._aggregation_thread is not None:
+            self._aggregation_thread.join()
+            self._aggregation_thread = None
+        for callback_id in self._registered_callbacks:
+            self._connector.unregister_callback(callback_id)
+        self._stop_event.clear()
+        self._connector.shutdown()
+        self.logger.info("Agent stopped")
diff --git a/src/rai_core/rai/agents/ros2/__init__.py b/src/rai_core/rai/agents/ros2/__init__.py
new file mode 100644
index 000000000..283f2732c
--- /dev/null
+++ b/src/rai_core/rai/agents/ros2/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .state_based_agent import ROS2StateBasedAgent
+
+__all__ = ["ROS2StateBasedAgent"]
diff --git a/src/rai_core/rai/agents/ros2/state_based_agent.py b/src/rai_core/rai/agents/ros2/state_based_agent.py
new file mode 100644
index 000000000..9ca503241
--- /dev/null
+++ b/src/rai_core/rai/agents/ros2/state_based_agent.py
@@ -0,0 +1,22 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from rai.communication.ros2 import ROS2Connector
+
+from ..langchain import BaseStateBasedAgent
+
+
+class ROS2StateBasedAgent(BaseStateBasedAgent):
+    def setup_connector(self):
+        return ROS2Connector()
diff --git a/src/rai_core/rai/agents/state_based.py b/src/rai_core/rai/agents/state_based.py
deleted file mode 100644
index d9e6fb7b8..000000000
--- a/src/rai_core/rai/agents/state_based.py
+++ /dev/null
@@ -1,184 +0,0 @@
-# Copyright (C) 2024 Robotec.AI
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-import time
-from functools import partial
-from typing import Any, Callable, Dict, List, Literal, Optional, TypedDict, Union
-
-from langchain.chat_models.base import BaseChatModel
-from langchain_core.language_models.base import LanguageModelInput
-from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, SystemMessage
-from langchain_core.runnables import Runnable
-from langchain_core.tools import BaseTool
-from langgraph.graph import END, START, StateGraph
-from langgraph.graph.graph import CompiledGraph
-from langgraph.prebuilt.tool_node import msg_content_output
-from pydantic import BaseModel, Field, ValidationError
-
-from rai.agents.tool_runner import ToolRunner
-from rai.messages import HumanMultimodalMessage
-
-
-class State(TypedDict):
-    messages: List[BaseMessage]
-
-
-class Report(BaseModel):
-    problem: str = Field(..., title="Problem", description="The problem that occurred")
-    steps: List[str] = Field(
-        ..., title="Steps", description="The steps taken to solve the problem"
-    )
-    success: bool = Field(
-        ..., title="Success", description="Whether the problem was solved"
-    )
-    outcome: str = Field(
-        ...,
-        title="Response",
-        description="Detailed outcome of the task, the response to the user",
-    )
-
-
-def tools_condition(
-    state: Union[list[AnyMessage], dict[str, Any]],
-) -> Literal["tools", "reporter"]:
-    if isinstance(state, list):
-        ai_message = state[-1]
-    elif messages := state.get("messages", []):
-        ai_message = messages[-1]
-    else:
-        raise ValueError(f"No messages found in input state to tool_edge: {state}")
-    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
-        return "tools"
-    return "reporter"
-
-
-def thinker(llm: BaseChatModel, logger: logging.Logger, state: State):
-    logger.info("Running thinker")
-    prompt = (
-        "Based on the data provided, reason about the situation. "
-        "Analyze the context, identify any problems or challenges, "
-        "and consider potential implications."
-    )
-    ai_msg = llm.invoke([SystemMessage(content=prompt)] + state["messages"])
-    state["messages"].append(ai_msg)
-    return state
-
-
-def decider(
-    llm: Runnable[LanguageModelInput, BaseMessage], logger: logging.Logger, state: State
-):
-    logger.info("Running decider")
-    prompt = (
-        "Based on the previous information, make a decision using tools if necessary. "
-        "If you are sure the problem has been solved, do not request any tools. "
-        "Request one tool at a time."
-    )
-
-    input = state["messages"] + [HumanMessage(prompt)]
-    ai_msg = llm.invoke(input)
-    state["messages"].append(ai_msg)
-    if ai_msg.tool_calls:
-        logger.info("Tools requested: {}".format(ai_msg.tool_calls))
-    return state
-
-
-def reporter(llm: BaseChatModel, logger: logging.Logger, state: State):
-    logger.info("Summarizing the conversation")
-    prompt = (
-        "You are the reporter. Your task is to summarize what happened previously. "
-        "Make sure to mention the problem, solution and the outcome. Prepare clear response to the user."
-    )
-    n_tries = 5
-    ai_msg = None
-    for i in range(n_tries):
-        try:
-            ai_msg = llm.with_structured_output(Report).invoke(
-                [SystemMessage(content=prompt)] + state["messages"]
-            )
-            break
-        except ValidationError:
-            logger.info(
-                f"Failed to summarize using given template. Repeating: {i}/{n_tries}"
-            )
-
-    if ai_msg is None:
-        logger.info("Failed to summarize. Trying without template")
-        ai_msg = llm.invoke([SystemMessage(content=prompt)] + state["messages"])
-
-    state["messages"].append(ai_msg)
-    return state
-
-
-def retriever_wrapper(
-    state_retriever: Callable[[], Dict[str, Any]], logger: logging.Logger, state: State
-):
-    """This wrapper is used to retrieve multimodal information from the output of state_retriever."""
-    ts = time.perf_counter()
-    retrieved_info = state_retriever()
-    te = time.perf_counter() - ts
-    logger.info(f"Retrieved state in {te} seconds")
-
-    images = retrieved_info.pop("images", [])
-    audios = retrieved_info.pop("audios", [])
-
-    info = msg_content_output(retrieved_info)
-    state["messages"].append(
-        HumanMultimodalMessage(
-            content=f"Retrieved state: {info}", images=images, audios=audios
-        )
-    )
-    return state
-
-
-def create_state_based_agent(
-    llm: BaseChatModel,
-    tools: List[BaseTool],
-    state_retriever: Callable[[], Dict[str, Any]],
-    logger: Optional[logging.Logger] = None,
-) -> CompiledGraph:
-    _logger = None
-    if isinstance(logger, logging.Logger):
-        _logger = logger
-    else:
-        _logger = logging.getLogger(__name__)
-
-    _logger.info("Creating state based agent")
-
-    llm_with_tools = llm.bind_tools(tools)
-    tool_node = ToolRunner(tools=tools, logger=_logger)
-
-    workflow = StateGraph(State)
-    workflow.add_node(
-        "state_retriever", partial(retriever_wrapper, state_retriever, _logger)
-    )
-    workflow.add_node("tools", tool_node)
-    # workflow.add_node("thinker", partial(thinker, llm, _logger))
-    workflow.add_node("decider", partial(decider, llm_with_tools, _logger))
-    workflow.add_node("reporter", partial(reporter, llm, _logger))
-
-    workflow.add_edge(START, "state_retriever")
-    workflow.add_edge("state_retriever", "decider")
-    # workflow.add_edge("thinker", "decider")
-    workflow.add_edge("tools", "state_retriever")
-    workflow.add_edge("reporter", END)
-    workflow.add_conditional_edges(
-        "decider",
-        tools_condition,
-    )
-
-    app = workflow.compile()
-    _logger.info("State based agent created")
-    return app
diff --git a/src/rai_core/rai/aggregators/__init__.py b/src/rai_core/rai/aggregators/__init__.py
new file mode 100644
index 000000000..254f91fe3
--- /dev/null
+++ b/src/rai_core/rai/aggregators/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseAggregator
+
+__all__ = ["BaseAggregator"]
diff --git a/src/rai_core/rai/aggregators/base.py b/src/rai_core/rai/aggregators/base.py
new file mode 100644
index 000000000..8a099ada9
--- /dev/null
+++ b/src/rai_core/rai/aggregators/base.py
@@ -0,0 +1,54 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from collections import deque
+from typing import Deque, Generic, List, TypeVar
+
+from langchain_core.messages import BaseMessage
+
+T = TypeVar("T")
+
+
+class BaseAggregator(ABC, Generic[T]):
+    """
+    Interface for aggregators.
+
+    `__call__` method receives a message and appends it to the buffer.
+    `get` method returns the aggregated message.
+    """
+
+    def __init__(self, max_size: int | None = None) -> None:
+        super().__init__()
+        self._buffer: Deque[T] = deque()
+        self.max_size = max_size
+
+    def __call__(self, msg: T) -> None:
+        if self.max_size is not None and len(self._buffer) >= self.max_size:
+            self._buffer.popleft()
+        self._buffer.append(msg)
+
+    @abstractmethod
+    def get(self) -> BaseMessage | None:
+        """Returns the aggregated message"""
+        pass
+
+    def clear_buffer(self) -> None:
+        self._buffer.clear()
+
+    def get_buffer(self) -> List[T]:
+        return list(self._buffer)
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}(len={len(self._buffer)})"
diff --git a/src/rai_core/rai/aggregators/ros2/__init__.py b/src/rai_core/rai/aggregators/ros2/__init__.py
new file mode 100644
index 000000000..569ce09eb
--- /dev/null
+++ b/src/rai_core/rai/aggregators/ros2/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
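+
+# These aggregators condense buffered ROS 2 messages (logs and camera frames)
+# into single LangChain messages that a state-based agent can inject into the
+# LLM context.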
+
+from .aggregators import (
+    ROS2GetLastImageAggregator,
+    ROS2ImgVLMDescriptionAggregator,
+    ROS2ImgVLMDiffAggregator,
+    ROS2LogsAggregator,
+)
+
+__all__ = [
+    "ROS2GetLastImageAggregator",
+    "ROS2ImgVLMDescriptionAggregator",
+    "ROS2ImgVLMDiffAggregator",
+    "ROS2LogsAggregator",
+]
diff --git a/src/rai_core/rai/aggregators/ros2/aggregators.py b/src/rai_core/rai/aggregators/ros2/aggregators.py
new file mode 100644
index 000000000..61172cadf
--- /dev/null
+++ b/src/rai_core/rai/aggregators/ros2/aggregators.py
@@ -0,0 +1,169 @@
+# Copyright (C) 2025 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, cast
+
+from langchain.chat_models.base import BaseChatModel
+from langchain_core.messages import HumanMessage, SystemMessage
+from pydantic import BaseModel, Field
+from rcl_interfaces.msg import Log
+from sensor_msgs.msg import CompressedImage, Image
+
+from rai.aggregators import BaseAggregator
+from rai.communication.ros2.api import convert_ros_img_to_base64
+from rai.initialization.model_initialization import get_llm_model
+from rai.messages import HumanMultimodalMessage
+
+
+class ROS2LogsAggregator(BaseAggregator[Log]):
+    """Returns only unique messages while keeping their order"""
+
+    levels = {10: "DEBUG", 20: "INFO", 30: "WARNING", 40: "ERROR", 50: "FATAL"}
+
+    def get(self) -> HumanMessage:
+        msgs = self.get_buffer()
+        buffer = []
+        prev_parsed = None
+        counter = 0
+        for log in msgs:
+            level = self.levels[log.level]
+            parsed = f"[{log.name}] [{level}] [{log.function}] {log.msg}"
+            if parsed == prev_parsed:
+                counter += 1
+                continue
+            # flush the repetition note before appending the new, distinct entry
+            if counter != 0:
+                buffer.append(f"Log above repeated {counter} times")
+                counter = 0
+            buffer.append(parsed)
+            prev_parsed = parsed
+        if counter != 0:
+            buffer.append(f"Log above repeated {counter} times")
+        result = f"Logs summary: {list(dict.fromkeys(buffer))}"
+        self.clear_buffer()
+        return HumanMessage(content=result)
+
+
+class ROS2GetLastImageAggregator(BaseAggregator[Image | CompressedImage]):
+    """Returns the last image from the buffer as a base64-encoded string"""
+
+    def get(self) -> HumanMultimodalMessage | None:
+        msgs = self.get_buffer()
+        if len(msgs) == 0:
+            return None
+        ros2_img = msgs[-1]
+        b64_image = convert_ros_img_to_base64(ros2_img)
+        self.clear_buffer()
+        return HumanMultimodalMessage(content="", images=[b64_image])
+
+
+class ROS2ImgVLMDescriptionAggregator(BaseAggregator[Image | CompressedImage]):
+    """
+    Returns the VLM analysis of the last image in the aggregation buffer
+    """
+
+    SYSTEM_PROMPT = "You are an expert in image analysis and your speciality is the description of images"
+
+    def __init__(
+        self, max_size: int | None = None, llm: BaseChatModel | None = None
+    ) -> None:
+        super().__init__(max_size)
+        if llm is None:
+            self.llm = get_llm_model(model_type="simple_model", streaming=True)
+        else:
+            self.llm = llm
+
+    def get(self) -> HumanMessage | None:
+        msgs: List[Image | CompressedImage] = self.get_buffer()
+        if len(msgs) == 0:
+            return None
+
+        # only the newest frame is described, so convert just that one
+        b64_image = convert_ros_img_to_base64(msgs[-1])
+        self.clear_buffer()
+
+        class ROS2ImgDescription(BaseModel):
+            key_elements: List[str] = Field(
+                ..., description="Key elements of the image"
+            )
+
+        task = [
+            SystemMessage(content=self.SYSTEM_PROMPT),
+            HumanMultimodalMessage(
+                content="Describe key elements that are currently in robot's view",
+                images=[b64_image],
+            ),
+        ]
+        llm = self.llm.with_structured_output(ROS2ImgDescription)
+        response = cast(ROS2ImgDescription, llm.invoke(task))
+        return HumanMessage(
+            content=f"These are the key elements of the last camera image frame: {response}"
+        )
+
+
+class ROS2ImgVLMDiffAggregator(BaseAggregator[Image | CompressedImage]):
+    """
+    Returns the VLM analysis of the differences between up to 3 images from
+    the aggregation buffer: the first, the middle, and the last
+    """
+
+    SYSTEM_PROMPT = "You are an expert in image analysis and your speciality is the comparison of 3 images"
+
+    def __init__(
+        self, max_size: int | None = None, llm: BaseChatModel | None = None
+    ) -> None:
+        super().__init__(max_size)
+        if llm is None:
+            self.llm = get_llm_model(model_type="simple_model", streaming=True)
+        else:
+            self.llm = llm
+
+    @staticmethod
+    def get_key_elements(elements: List[Any]) -> List[Any]:
+        """
+        Returns the first, middle and last elements of the list
+        """
+        if len(elements) <= 3:
+            return elements
+        middle_index = len(elements) // 2
+        return [elements[0], elements[middle_index], elements[-1]]
+
+    def get(self) -> HumanMessage | None:
+        msgs = self.get_buffer()
+        if len(msgs) == 0:
+            return None
+
+        b64_images = [convert_ros_img_to_base64(msg) for msg in msgs]
+
+        self.clear_buffer()
+
+        b64_images = self.get_key_elements(b64_images)
+
+        class ROS2ImgDiffOutput(BaseModel):
+            are_different: bool = Field(
+                ..., description="Whether the images are different"
+            )
+            differences: List[str] = Field(
+                ..., description="Description of the difference"
+            )
+
+        task = [
+            SystemMessage(content=self.SYSTEM_PROMPT),
+            HumanMultimodalMessage(
+                content="Here are up to 3 consecutive images from the robot camera. The robot might be moving. Outline key differences in robot's view.",
+                images=b64_images,
+            ),
+        ]
+        llm = self.llm.with_structured_output(ROS2ImgDiffOutput)
+        response = cast(ROS2ImgDiffOutput, llm.invoke(task))
+        return HumanMessage(
+            content=f"Result of the analysis of the {len(b64_images)} keyframes selected from the {len(msgs)} last images:\n{response}"
+        )
diff --git a/src/rai_core/rai/communication/ros2/api/conversion.py b/src/rai_core/rai/communication/ros2/api/conversion.py
index 86cdfbea9..23dc0891f 100644
--- a/src/rai_core/rai/communication/ros2/api/conversion.py
+++ b/src/rai_core/rai/communication/ros2/api/conversion.py
@@ -92,20 +92,28 @@ def convert_ros_img_to_cv2mat(msg: sensor_msgs.msg.Image) -> cv2.typing.MatLike:
     return cv_image
 
 
-def convert_ros_img_to_base64(msg: sensor_msgs.msg.Image) -> str:
+def convert_ros_img_to_base64(
+    msg: sensor_msgs.msg.Image | sensor_msgs.msg.CompressedImage,
+) -> str:
     bridge = CvBridge()
-    cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough"))  # type: ignore
+    msg_type = type(msg)
+    if msg_type == sensor_msgs.msg.Image:
+        cv_image = bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")
+    elif msg_type == sensor_msgs.msg.CompressedImage:
+        cv_image = bridge.compressed_imgmsg_to_cv2(msg, desired_encoding="passthrough")
+    else:
+        raise ValueError(f"Unsupported message type: {msg_type}")
+
     if cv_image.shape[-1] == 4:
         cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGRA2RGB)
         return base64.b64encode(bytes(cv2.imencode(".png", cv_image)[1])).decode(
             "utf-8"
         )
     elif cv_image.shape[-1] == 1:
-        cv_image = cv2.cvtColor(cv_image, cv2.GRAY2RGB)
+        cv_image = cv2.cvtColor(cv_image, cv2.COLOR_GRAY2RGB)
         return base64.b64encode(bytes(cv2.imencode(".png", cv_image)[1])).decode(
             "utf-8"
         )
-    else:
         cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
         image_data = cv2.imencode(".png", cv_image)[1].tostring()  # type: ignore