Skip to content

feat: add StateBaseAgent #529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/developer_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ Choose the conversational_agent when beginning with RAI, and consider upgrading
from myrobot import robot

from rai.agents.conversational_agent import create_conversational_agent
from rai.agents.state_based import create_state_based_agent
from rai.agents.langchain import create_state_based_runnable
from rai import get_llm_model

SYSTEM_PROMPT = "You are a robot with interfaces..."
Expand All @@ -144,7 +144,7 @@ tools = [pick_up_object, scan_object, SayTool(robot=robot)]
conversational_agent = create_conversational_agent(
llm=llm, tools=tools, system_prompt=SYSTEM_PROMPT
)
state_based_agent = create_state_based_agent(
state_based_agent = create_state_based_runnable(
llm=llm, state_retriever=state_retriever, tools=tools
)

Expand Down
57 changes: 57 additions & 0 deletions examples/agents/state_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language goveself.rning permissions and
# limitations under the License.

from rai.agents import wait_for_shutdown
from rai.agents.langchain import StateBasedConfig
from rai.agents.ros2 import ROS2StateBasedAgent
from rai.aggregators.ros2 import (
ROS2ImgVLMDiffAggregator,
ROS2LogsAggregator,
)
from rai.communication.ros2 import (
ROS2Connector,
ROS2Context,
ROS2HRIConnector,
)
from rai.tools.ros2.generic.toolkit import ROS2Toolkit


@ROS2Context()
def main():
hri_connector = ROS2HRIConnector()
ros2_connector = ROS2Connector()

config = StateBasedConfig(
aggregators={
("/camera/camera/color/image_raw", "sensor_msgs/msg/Image"): [
ROS2ImgVLMDiffAggregator()
],
"/rosout": [
ROS2LogsAggregator()
], # if msg_type is not provided, topic has to exist
}
)

agent = ROS2StateBasedAgent(
config=config,
target_connectors={"to_human": hri_connector},
tools=ROS2Toolkit(connector=ros2_connector).get_tools(),
)
agent.subscribe_source("/from_human", hri_connector)
agent.run()
wait_for_shutdown([agent])


if __name__ == "__main__":
main()
2 changes: 0 additions & 2 deletions src/rai_core/rai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from rai.agents.conversational_agent import create_conversational_agent
from rai.agents.langchain.react_agent import ReActAgent
from rai.agents.runner import AgentRunner, wait_for_shutdown
from rai.agents.state_based import create_state_based_agent
from rai.agents.tool_runner import ToolRunner

__all__ = [
Expand All @@ -25,6 +24,5 @@
"ReActAgent",
"ToolRunner",
"create_conversational_agent",
"create_state_based_agent",
"wait_for_shutdown",
]
22 changes: 20 additions & 2 deletions src/rai_core/rai/agents/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .agent import BaseState, LangChainAgent, newMessageBehaviorType
from .callback import HRICallbackHandler
from .runnables import create_react_runnable
from .react_agent import ReActAgent
from .runnables import (
ReActAgentState,
create_react_runnable,
create_state_based_runnable,
)
from .state_based_agent import BaseStateBasedAgent, StateBasedConfig

__all__ = ["HRICallbackHandler", "create_react_runnable"]
__all__ = [
"BaseState",
"BaseStateBasedAgent",
"HRICallbackHandler",
"LangChainAgent",
"ReActAgent",
"ReActAgentState",
"StateBasedConfig",
"create_react_runnable",
"create_state_based_runnable",
"newMessageBehaviorType",
]
9 changes: 6 additions & 3 deletions src/rai_core/rai/agents/langchain/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
from langchain_core.messages import BaseMessage
from langchain_core.runnables import Runnable

from rai.agents.base import BaseAgent
from rai.agents.langchain import HRICallbackHandler
from rai.agents.langchain.runnables import ReActAgentState
from rai.communication.hri_connector import HRIConnector, HRIMessage
from rai.initialization import get_tracing_callbacks

from ..base import BaseAgent
from .callback import HRICallbackHandler
from .runnables import ReActAgentState


class BaseState(TypedDict):
messages: List[BaseMessage]
Expand Down Expand Up @@ -208,6 +209,7 @@ def _run_loop(self):
self._run_agent()

def stop(self):
"""Stop the agent's execution loop."""
self._stop_event.set()
self._interrupt_event.set()
self._agent_ready_event.wait()
Expand All @@ -216,6 +218,7 @@ def stop(self):
self._thread.join()
self._thread = None
self.logger.info("Agent stopped")
self._stop_event.clear()

@staticmethod
def _apply_reduction_behavior(
Expand Down
6 changes: 3 additions & 3 deletions src/rai_core/rai/agents/langchain/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import BaseTool

from rai.agents.langchain import create_react_runnable
from rai.agents.langchain.agent import LangChainAgent
from rai.agents.langchain.runnables import ReActAgentState
from rai.communication.hri_connector import HRIConnector, HRIMessage

from .agent import LangChainAgent
from .runnables import ReActAgentState, create_react_runnable


class ReActAgent(LangChainAgent):
def __init__(
Expand Down
64 changes: 61 additions & 3 deletions src/rai_core/rai/agents/langchain/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from functools import partial
from typing import List, Optional, TypedDict, cast
from typing import (
Any,
Callable,
Dict,
List,
Optional,
TypedDict,
cast,
)

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, SystemMessage
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langgraph.graph import START, StateGraph
from langgraph.prebuilt.tool_node import tools_condition

from rai.agents.tool_runner import ToolRunner
from rai.initialization import get_llm_model
from rai.messages import HumanMultimodalMessage

from ..tool_runner import ToolRunner


class ReActAgentState(TypedDict):
Expand Down Expand Up @@ -112,3 +123,50 @@ def create_react_runnable(

# Compile the graph
return graph.compile()


def retriever_wrapper(
state_retriever: Callable[[], Dict[str, HumanMessage | HumanMultimodalMessage]],
state: ReActAgentState,
):
"""This wrapper is used to put state messages into LLM context"""
for source, message in state_retriever().items():
message.content = f"{source}: {message.content}"
logging.getLogger("state_retriever").debug(
f"Adding state message:\n{message.pretty_repr()}"
)
state["messages"].append(message)
return state


def create_state_based_runnable(
llm: Optional[BaseChatModel] = None,
tools: Optional[List[BaseTool]] = None,
system_prompt: Optional[str] = None,
state_retriever: Optional[Callable[[], Dict[str, Any]]] = None,
) -> Runnable[ReActAgentState, ReActAgentState]:
if llm is None:
llm = get_llm_model("complex_model", streaming=True)
graph = StateGraph(ReActAgentState)
graph.add_edge(START, "state_retriever")
graph.add_edge("state_retriever", "llm")
graph.add_conditional_edges(
"llm",
tools_condition,
)
graph.add_edge("tools", "state_retriever")

if state_retriever is None:
state_retriever = lambda: {}

graph.add_node("state_retriever", partial(retriever_wrapper, state_retriever))

if tools is None:
tools = []
bound_llm = cast(BaseChatModel, llm.bind_tools(tools))
graph.add_node("llm", partial(llm_node, bound_llm, system_prompt))

tool_runner = ToolRunner(tools)
graph.add_node("tools", tool_runner)

return graph.compile()
Loading